mirror of https://github.com/vllm-project/vllm.git
Compare commits
1 Commits
Author | SHA1 | Date |
---|---|---|
|
a5dd03c1eb |
|
@ -66,10 +66,10 @@ function cpu_tests() {
|
|||
tests/quantization/test_compressed_tensors.py::test_compressed_tensors_w8a8_dynamic_per_token"
|
||||
|
||||
# Run AWQ test
|
||||
# docker exec cpu-test-"$NUMA_NODE" bash -c "
|
||||
# set -e
|
||||
# VLLM_USE_V1=0 pytest -s -v \
|
||||
# tests/quantization/test_ipex_quant.py"
|
||||
docker exec cpu-test-"$NUMA_NODE" bash -c "
|
||||
set -e
|
||||
VLLM_USE_V1=0 pytest -s -v \
|
||||
tests/quantization/test_ipex_quant.py"
|
||||
|
||||
# Run chunked-prefill and prefix-cache test
|
||||
docker exec cpu-test-"$NUMA_NODE" bash -c "
|
||||
|
|
|
@ -26,5 +26,7 @@ docker run \
|
|||
--name "${container_name}" \
|
||||
"${image_name}" \
|
||||
sh -c '
|
||||
VLLM_USE_V1=0 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m
|
||||
VLLM_USE_V1=0 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m -tp 2
|
||||
VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager
|
||||
'
|
||||
|
|
|
@ -8,7 +8,7 @@ image:
|
|||
# -- Image tag
|
||||
tag: "latest"
|
||||
# -- Container launch command
|
||||
command: ["vllm", "serve", "/data/", "--served-model-name", "opt-125m", "--enforce-eager", "--dtype", "bfloat16", "--block-size", "16", "--host", "0.0.0.0", "--port", "8000"]
|
||||
command: ["vllm", "serve", "/data/", "--served-model-name", "opt-125m", "--dtype", "float32", "--block-size", "16", "--host", "0.0.0.0", "--port", "8000"]
|
||||
|
||||
# -- Container port
|
||||
containerPort: 8000
|
||||
|
|
|
@ -36,8 +36,7 @@ DEVICE_REGULAR_ATTN_BACKENDS = {
|
|||
DEVICE_MLA_BLOCK_SIZES = {
|
||||
"cuda": [16, 64], # CUDA supports both standard and extended block sizes
|
||||
"hip": [16, 1], # HIP requires special handling for block_size=1
|
||||
# "cpu": [16] # CPU uses fixed block size from test cases
|
||||
"cpu": [] # FIXME(woosuk): Temporarily disable CPU tests
|
||||
"cpu": [16] # CPU uses fixed block size from test cases
|
||||
}
|
||||
|
||||
|
||||
|
@ -82,14 +81,14 @@ def test_env(
|
|||
m.setenv("VLLM_MLA_DISABLE", "1" if use_mla else "0")
|
||||
|
||||
if device == "cpu":
|
||||
if not use_v1:
|
||||
pytest.skip("CPU backend only supports V1")
|
||||
|
||||
with patch("vllm.attention.selector.current_platform",
|
||||
CpuPlatform()):
|
||||
backend = get_attn_backend(16, torch.float16, torch.float16,
|
||||
block_size, False)
|
||||
assert backend.get_name() == "TORCH_SDPA_VLLM_V1"
|
||||
if use_v1:
|
||||
assert backend.get_name() == "TORCH_SDPA_VLLM_V1"
|
||||
else:
|
||||
assert backend.get_name() == "TORCH_SDPA"
|
||||
|
||||
elif device == "hip":
|
||||
with patch("vllm.attention.selector.current_platform",
|
||||
|
@ -205,14 +204,12 @@ def test_fp32_fallback(
|
|||
m.setenv("VLLM_USE_V1", "1" if use_v1 else "0")
|
||||
|
||||
if device == "cpu":
|
||||
if not use_v1:
|
||||
pytest.skip("CPU backend only supports V1")
|
||||
|
||||
with patch("vllm.attention.selector.current_platform",
|
||||
CpuPlatform()):
|
||||
backend = get_attn_backend(16, torch.float32, torch.float32,
|
||||
16, False)
|
||||
assert backend.get_name() == "TORCH_SDPA_VLLM_V1"
|
||||
assert (backend.get_name() == "TORCH_SDPA_VLLM_V1"
|
||||
if use_v1 else "TORCH_SDPA")
|
||||
|
||||
elif device == "cuda":
|
||||
with patch("vllm.attention.selector.current_platform",
|
||||
|
|
|
@ -0,0 +1,307 @@
|
|||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type
|
||||
|
||||
import torch
|
||||
|
||||
import vllm._custom_ops as ops
|
||||
from vllm._ipex_ops import ipex_ops
|
||||
from vllm.attention.backends.abstract import (AttentionBackend,
|
||||
AttentionMetadataBuilder,
|
||||
AttentionType,
|
||||
is_quantized_kv_cache)
|
||||
from vllm.attention.backends.mla.common import MLACommonImpl, MLACommonState
|
||||
from vllm.attention.backends.torch_sdpa import TorchSDPAMetadata
|
||||
from vllm.utils import make_tensor_with_pad
|
||||
from vllm.worker.cpu_model_runner import ModelInputForCPUBuilder
|
||||
|
||||
|
||||
class CPUMLABackend(AttentionBackend):
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "CPU_MLA"
|
||||
|
||||
@staticmethod
|
||||
def get_metadata_cls() -> Type["CPUMLAMetadata"]:
|
||||
return CPUMLAMetadata
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> Type["CPUMLAMetadataBuilder"]:
|
||||
return CPUMLAMetadataBuilder
|
||||
|
||||
@staticmethod
|
||||
def get_state_cls() -> Type["MLACommonState"]:
|
||||
return MLACommonState
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> Type["CPUMLAImpl"]:
|
||||
return CPUMLAImpl
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_shape(
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
num_kv_heads: int, # assumed to be 1 for MLA
|
||||
head_size: int,
|
||||
) -> Tuple[int, ...]:
|
||||
return (num_blocks, block_size, head_size)
|
||||
|
||||
@staticmethod
|
||||
def swap_blocks(
|
||||
src_kv_cache: torch.Tensor,
|
||||
dst_kv_cache: torch.Tensor,
|
||||
src_to_dst: torch.Tensor,
|
||||
) -> None:
|
||||
ops.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
|
||||
|
||||
@staticmethod
|
||||
def copy_blocks(
|
||||
kv_caches: List[torch.Tensor],
|
||||
src_to_dists: torch.Tensor,
|
||||
) -> None:
|
||||
ops.copy_blocks_mla(kv_caches, src_to_dists)
|
||||
|
||||
@staticmethod
|
||||
def get_supported_head_sizes() -> List[int]:
|
||||
return [576]
|
||||
|
||||
|
||||
@dataclass
|
||||
class CPUMLAMetadata(TorchSDPAMetadata):
|
||||
# New for MLA
|
||||
# Input positions for rotrary embeddings since for MLA the rotary
|
||||
# position embeddings are applied inside the attention backend
|
||||
input_positions: torch.Tensor = None
|
||||
|
||||
# required by MLACommonImpl
|
||||
is_profile_run: bool = False
|
||||
|
||||
|
||||
class CPUMLAMetadataBuilder(AttentionMetadataBuilder[CPUMLAMetadata]):
|
||||
|
||||
def __init__(self, input_builder: ModelInputForCPUBuilder) -> None:
|
||||
self.chunked_prefill = input_builder.chunked_prefill
|
||||
self.input_builder = input_builder
|
||||
assert not self.chunked_prefill, \
|
||||
"chunked prefill is currently not supported"
|
||||
|
||||
def prepare(self):
|
||||
self.input_data = self.input_builder.input_data
|
||||
|
||||
def build(self, seq_lens, query_lens, cuda_graph_pad_size, batch_size):
|
||||
input_data = self.input_data
|
||||
prefill_seq_lens = seq_lens[0:input_data.num_prefills]
|
||||
prefill_query_lens = query_lens[0:input_data.num_prefills]
|
||||
slot_mapping = torch.tensor(input_data.slot_mapping,
|
||||
dtype=torch.long,
|
||||
device="cpu")
|
||||
|
||||
# metadata for prefill
|
||||
if input_data.num_prefills > 0:
|
||||
query_lens_tensor = torch.tensor(prefill_query_lens,
|
||||
dtype=torch.int32,
|
||||
device="cpu")
|
||||
kv_lens_tensor = torch.tensor(prefill_seq_lens,
|
||||
dtype=torch.int32,
|
||||
device="cpu")
|
||||
query_start_loc = torch.zeros(input_data.num_prefills + 1,
|
||||
dtype=torch.int32,
|
||||
device="cpu")
|
||||
kv_start_loc = torch.zeros(input_data.num_prefills + 1,
|
||||
dtype=torch.int32,
|
||||
device="cpu")
|
||||
torch.cumsum(query_lens_tensor,
|
||||
dim=0,
|
||||
dtype=torch.int32,
|
||||
out=query_start_loc[1:])
|
||||
torch.cumsum(kv_lens_tensor,
|
||||
dim=0,
|
||||
dtype=torch.int32,
|
||||
out=kv_start_loc[1:])
|
||||
max_query_len = max(prefill_query_lens)
|
||||
max_kv_len = max(prefill_seq_lens)
|
||||
|
||||
# for chunked-prefill
|
||||
if self.chunked_prefill:
|
||||
prefill_block_tables = make_tensor_with_pad(
|
||||
self.input_data.prefill_block_tables,
|
||||
pad=0,
|
||||
dtype=torch.int32,
|
||||
device="cpu",
|
||||
)
|
||||
else:
|
||||
prefill_block_tables = None
|
||||
|
||||
else:
|
||||
query_start_loc = None
|
||||
kv_start_loc = None
|
||||
max_query_len = None
|
||||
max_kv_len = None
|
||||
prefill_block_tables = None
|
||||
|
||||
# metadata for decode
|
||||
if input_data.num_decode_tokens != 0:
|
||||
seq_lens_tensor = torch.tensor(
|
||||
input_data.seq_lens[input_data.num_prefills:],
|
||||
dtype=torch.int32,
|
||||
device="cpu",
|
||||
)
|
||||
block_tables = make_tensor_with_pad(
|
||||
self.input_data.decode_block_tables,
|
||||
pad=0,
|
||||
dtype=torch.int32,
|
||||
device="cpu",
|
||||
)
|
||||
else:
|
||||
block_tables = torch.tensor([])
|
||||
seq_lens_tensor = torch.tensor(
|
||||
input_data.seq_lens[:input_data.num_prefills],
|
||||
dtype=torch.int32,
|
||||
device="cpu",
|
||||
)
|
||||
|
||||
# For multi-modal models
|
||||
placeholder_index_maps = None
|
||||
if len(input_data.multi_modal_inputs_list) != 0:
|
||||
placeholder_index_maps = {
|
||||
modality: placeholder_map.index_map()
|
||||
for modality, placeholder_map in
|
||||
input_data.multi_modal_placeholder_maps.items()
|
||||
}
|
||||
|
||||
return CPUMLAMetadata(
|
||||
chunked_prefill=self.chunked_prefill,
|
||||
seq_lens=prefill_seq_lens,
|
||||
seq_lens_tensor=seq_lens_tensor,
|
||||
max_query_len=max_query_len,
|
||||
max_kv_len=max_kv_len,
|
||||
prefill_query_start_loc=query_start_loc,
|
||||
kv_start_loc=kv_start_loc,
|
||||
max_decode_seq_len=input_data.max_decode_seq_len,
|
||||
num_prefills=input_data.num_prefills,
|
||||
num_prefill_tokens=input_data.num_prefill_tokens,
|
||||
num_decode_tokens=input_data.num_decode_tokens,
|
||||
block_tables=block_tables,
|
||||
prefill_block_tables=prefill_block_tables,
|
||||
slot_mapping=slot_mapping,
|
||||
multi_modal_placeholder_index_maps=placeholder_index_maps,
|
||||
enable_kv_scales_calculation=False,
|
||||
input_positions=torch.tensor([self.input_data.input_positions]))
|
||||
|
||||
|
||||
class CPUMLAImpl(MLACommonImpl[CPUMLAMetadata]):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: Optional[List[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
blocksparse_params: Optional[Dict[str, Any]],
|
||||
logits_soft_cap: Optional[float],
|
||||
attn_type: str,
|
||||
kv_sharing_target_layer_name: Optional[str],
|
||||
# MLA Specific Arguments
|
||||
**mla_args) -> None:
|
||||
super().__init__(num_heads, head_size, scale, num_kv_heads,
|
||||
alibi_slopes, sliding_window, kv_cache_dtype,
|
||||
blocksparse_params, logits_soft_cap, attn_type,
|
||||
kv_sharing_target_layer_name, **mla_args)
|
||||
|
||||
unsupported_features = [
|
||||
alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
|
||||
]
|
||||
if any(unsupported_features):
|
||||
raise NotImplementedError(
|
||||
"CPUMLAImpl does not support one of the following: "
|
||||
"alibi_slopes, sliding_window, blocksparse_params, "
|
||||
"logits_soft_cap")
|
||||
|
||||
if attn_type != AttentionType.DECODER:
|
||||
raise NotImplementedError("Encoder self-attention and "
|
||||
"encoder/decoder cross-attention "
|
||||
"are not implemented for "
|
||||
"CPUMLAImpl")
|
||||
|
||||
# states is implemented.
|
||||
if is_quantized_kv_cache(self.kv_cache_dtype):
|
||||
raise NotImplementedError(
|
||||
"CPUMLAImpl with FP8 KV cache not yet supported")
|
||||
|
||||
def _forward_prefill(
|
||||
self,
|
||||
q: torch.Tensor,
|
||||
kv_c_normed: torch.Tensor,
|
||||
k_pe: torch.Tensor,
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
attn_metadata: CPUMLAMetadata, # type: ignore[override]
|
||||
) -> torch.Tensor:
|
||||
|
||||
prefill_metadata = attn_metadata.prefill_metadata
|
||||
assert prefill_metadata is not None
|
||||
|
||||
kv_nope = self.kv_b_proj(kv_c_normed)[0].view(\
|
||||
-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
|
||||
k_nope, v = kv_nope\
|
||||
.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
||||
|
||||
k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
|
||||
|
||||
# For MLA the v head dim is smaller than qk head dim so we pad out
|
||||
# v with 0s to match the qk head dim
|
||||
v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]],
|
||||
value=0)
|
||||
|
||||
output = torch.empty_like(q)
|
||||
ipex_ops.varlen_attention(
|
||||
query=q,
|
||||
key=k,
|
||||
value=v_padded,
|
||||
out=output,
|
||||
seqlen_q=prefill_metadata.prefill_query_start_loc,
|
||||
seqlen_k=prefill_metadata.prefill_query_start_loc,
|
||||
max_seqlen_q=prefill_metadata.max_query_len,
|
||||
max_seqlen_k=prefill_metadata.max_query_len,
|
||||
pdropout=0.0,
|
||||
softmax_scale=self.scale,
|
||||
zero_tensors=False,
|
||||
is_causal=True,
|
||||
return_softmax=False,
|
||||
gen_=None,
|
||||
logits_soft_cap=0.0,
|
||||
window_size_left=-1,
|
||||
window_size_right=-1,
|
||||
alibi_slopes=None,
|
||||
)
|
||||
|
||||
# remove padding
|
||||
output = output.view(-1, self.num_heads,
|
||||
q.shape[-1])[..., :v.shape[-1]]
|
||||
return output.reshape(-1, self.num_heads * v.shape[-1])
|
||||
|
||||
def _forward_decode(
|
||||
self,
|
||||
q_nope: torch.Tensor,
|
||||
q_pe: torch.Tensor,
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
attn_metadata: CPUMLAMetadata, # type: ignore[override]
|
||||
) -> torch.Tensor:
|
||||
assert kv_c_and_k_pe_cache.numel() > 0
|
||||
|
||||
decode_meta = attn_metadata.decode_metadata
|
||||
assert decode_meta is not None
|
||||
|
||||
q = torch.cat([q_nope, q_pe], dim=-1)
|
||||
o = q.new_empty(q.shape[0], self.num_heads, self.kv_lora_rank)
|
||||
|
||||
# Run MQA
|
||||
ops.mla_decode_kvcache_cpu(o, q, kv_c_and_k_pe_cache, self.scale,
|
||||
decode_meta.block_tables,
|
||||
decode_meta.seq_lens_tensor)
|
||||
return self._v_up_proj(o)
|
|
@ -0,0 +1,403 @@
|
|||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
""" Attention layer with torch scaled_dot_product_attention
|
||||
and PagedAttention."""
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type
|
||||
|
||||
import torch
|
||||
|
||||
from vllm._ipex_ops import ipex_ops
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionLayer,
|
||||
AttentionMetadata, AttentionType,
|
||||
is_quantized_kv_cache)
|
||||
from vllm.attention.backends.utils import CommonAttentionState
|
||||
from vllm.attention.ops.paged_attn import (PagedAttention,
|
||||
PagedAttentionMetadata)
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
_PARTITION_SIZE = 512
|
||||
|
||||
|
||||
class IpexAttnBackend(AttentionBackend):
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "IPEX"
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> Type["IpexAttnBackendImpl"]:
|
||||
return IpexAttnBackendImpl
|
||||
|
||||
@staticmethod
|
||||
def get_metadata_cls() -> Type["IpexAttnMetadata"]:
|
||||
return IpexAttnMetadata
|
||||
|
||||
@staticmethod
|
||||
def get_state_cls() -> Type["CommonAttentionState"]:
|
||||
return CommonAttentionState
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_shape(
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
) -> Tuple[int, ...]:
|
||||
return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
|
||||
num_kv_heads, head_size)
|
||||
|
||||
@staticmethod
|
||||
def swap_blocks(
|
||||
src_kv_cache: torch.Tensor,
|
||||
dst_kv_cache: torch.Tensor,
|
||||
src_to_dst: torch.Tensor,
|
||||
) -> None:
|
||||
from vllm._ipex_ops import ipex_ops as ops
|
||||
ops.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
|
||||
|
||||
@staticmethod
|
||||
def copy_blocks(
|
||||
kv_caches: List[torch.Tensor],
|
||||
src_to_dists: torch.Tensor,
|
||||
) -> None:
|
||||
from vllm._ipex_ops import ipex_ops as ops
|
||||
key_caches = [kv_cache[0] for kv_cache in kv_caches]
|
||||
value_caches = [kv_cache[1] for kv_cache in kv_caches]
|
||||
ops.copy_blocks(key_caches, value_caches, src_to_dists)
|
||||
|
||||
|
||||
@dataclass
|
||||
class IpexAttnMetadata(AttentionMetadata, PagedAttentionMetadata):
|
||||
"""Metadata for IpexAttnBackend.
|
||||
"""
|
||||
# Currently, input sequences can only contain all prompts
|
||||
# or all decoding. True if all sequences are prompts.
|
||||
is_prompt: bool
|
||||
slot_mapping: torch.Tensor
|
||||
seq_lens: Optional[List[int]]
|
||||
seqlen_q: Optional[torch.Tensor]
|
||||
max_seqlen: Optional[int]
|
||||
|
||||
def __post_init__(self):
|
||||
# Set during the execution of the first attention op.
|
||||
# It is a list because it is needed to set per prompt
|
||||
# when alibi slopes is used. It is because of the limitation
|
||||
# from xformer API.
|
||||
# will not appear in the __repr__ and __init__
|
||||
self.attn_bias: Optional[List[torch.Tensor]] = None
|
||||
|
||||
@property
|
||||
def prefill_metadata(self) -> Optional["IpexAttnMetadata"]:
|
||||
# Currently chunked prefill is not supported
|
||||
if self.num_decode_tokens == 0:
|
||||
assert self.num_prefills > 0
|
||||
return self
|
||||
|
||||
return None
|
||||
|
||||
@property
|
||||
def decode_metadata(self) -> Optional["IpexAttnMetadata"]:
|
||||
# Currently chunked prefill is not supported
|
||||
if self.num_prefills > 0:
|
||||
assert self.num_decode_tokens == 0
|
||||
return None
|
||||
|
||||
return self
|
||||
|
||||
|
||||
class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: Optional[List[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
blocksparse_params: Optional[Dict[str, Any]] = None,
|
||||
logits_soft_cap: Optional[float] = None,
|
||||
attn_type: str = AttentionType.DECODER,
|
||||
kv_sharing_target_layer_name: Optional[str] = None,
|
||||
use_irope: bool = False,
|
||||
) -> None:
|
||||
if kv_sharing_target_layer_name is not None:
|
||||
raise NotImplementedError("KV sharing is not supported in V0.")
|
||||
if use_irope:
|
||||
logger.warning_once(
|
||||
"Using irope in Ipex is not supported yet, it will fall"
|
||||
" back to global attention for long context.")
|
||||
if blocksparse_params is not None:
|
||||
raise ValueError(
|
||||
"IPEX backend does not support block-sparse attention.")
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.scale = float(scale)
|
||||
self.num_kv_heads = num_kv_heads
|
||||
if alibi_slopes is not None:
|
||||
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
|
||||
self.alibi_slopes = alibi_slopes
|
||||
self.sliding_window = sliding_window
|
||||
self.kv_cache_dtype = kv_cache_dtype
|
||||
|
||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||
self.need_mask = (self.sliding_window is not None)
|
||||
if logits_soft_cap is None:
|
||||
logits_soft_cap = -1
|
||||
self.logits_soft_cap = logits_soft_cap
|
||||
|
||||
supported_head_sizes = PagedAttention.get_supported_head_sizes()
|
||||
if head_size not in supported_head_sizes:
|
||||
raise ValueError(
|
||||
f"Head size {head_size} is not supported by PagedAttention. "
|
||||
f"Supported head sizes are: {supported_head_sizes}.")
|
||||
if is_quantized_kv_cache(kv_cache_dtype):
|
||||
raise NotImplementedError(
|
||||
"IPEX backend does not support FP8 KV cache. "
|
||||
"Please use xFormers backend instead.")
|
||||
if attn_type != AttentionType.DECODER:
|
||||
raise NotImplementedError("Encoder self-attention and "
|
||||
"encoder/decoder cross-attention "
|
||||
"are not implemented for "
|
||||
"IpexAttnBackendImpl")
|
||||
|
||||
def split_kv_cache(
|
||||
self,
|
||||
kv_cache: torch.Tensor,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
x = 1
|
||||
num_blocks = kv_cache.shape[1]
|
||||
|
||||
key_cache = kv_cache[0]
|
||||
key_cache = key_cache.view(num_blocks, num_kv_heads, head_size // x,
|
||||
-1, x)
|
||||
value_cache = kv_cache[1]
|
||||
value_cache = value_cache.view(num_blocks, num_kv_heads, head_size, -1)
|
||||
return key_cache, value_cache
|
||||
|
||||
def forward(
|
||||
self,
|
||||
layer: AttentionLayer,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: torch.Tensor,
|
||||
attn_metadata: IpexAttnMetadata, # type: ignore
|
||||
output: Optional[torch.Tensor] = None,
|
||||
output_scale: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass with IPEX varlen_attention and PagedAttention.
|
||||
|
||||
Args:
|
||||
query: shape = [num_tokens, num_heads * head_size]
|
||||
key: shape = [num_tokens, num_kv_heads * head_size]
|
||||
value: shape = [num_tokens, num_kv_heads * head_size]
|
||||
kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
|
||||
NOTE: kv_cache will be an empty tensor with shape [0]
|
||||
for profiling run.
|
||||
attn_metadata: Metadata for attention.
|
||||
Returns:
|
||||
shape = [num_tokens, num_heads * head_size]
|
||||
"""
|
||||
if output_scale is not None:
|
||||
raise NotImplementedError(
|
||||
"fused output quantization is not yet supported"
|
||||
" for IpexAttentionImpl")
|
||||
|
||||
assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
|
||||
num_tokens, hidden_size = query.shape
|
||||
# Reshape the query, key, and value tensors.
|
||||
query = query.view(-1, self.num_heads, self.head_size)
|
||||
key = key.view(-1, self.num_kv_heads, self.head_size)
|
||||
value = value.view(-1, self.num_kv_heads, self.head_size)
|
||||
|
||||
if kv_cache.numel() > 0:
|
||||
key_cache, value_cache = self.split_kv_cache(
|
||||
kv_cache, self.num_kv_heads, self.head_size)
|
||||
ipex_ops.reshape_and_cache(
|
||||
key,
|
||||
value,
|
||||
key_cache,
|
||||
value_cache,
|
||||
attn_metadata.slot_mapping.flatten(),
|
||||
self.kv_cache_dtype,
|
||||
layer._k_scale_float,
|
||||
layer._v_scale_float,
|
||||
)
|
||||
|
||||
if attn_metadata.is_prompt:
|
||||
assert attn_metadata.seq_lens is not None
|
||||
if (kv_cache.numel() == 0
|
||||
or attn_metadata.block_tables.numel() == 0):
|
||||
if self.num_kv_heads != self.num_heads:
|
||||
key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
|
||||
value = value.repeat_interleave(self.num_queries_per_kv,
|
||||
dim=1)
|
||||
|
||||
if attn_metadata.attn_bias is None:
|
||||
if self.sliding_window is not None:
|
||||
att_masks = _make_sliding_window_bias(
|
||||
attn_metadata.seq_lens, self.sliding_window,
|
||||
query.dtype) # type: ignore
|
||||
else:
|
||||
att_masks = _make_sliding_window_bias(
|
||||
attn_metadata.seq_lens, None, dtype=query.dtype)
|
||||
attn_metadata.attn_bias = att_masks
|
||||
|
||||
output = torch.empty(
|
||||
(num_tokens, self.num_heads, self.head_size),
|
||||
dtype=query.dtype,
|
||||
device=query.device)
|
||||
ipex_ops.varlen_attention(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
output,
|
||||
attn_metadata.seqlen_q,
|
||||
attn_metadata.seqlen_q,
|
||||
self.alibi_slopes,
|
||||
attn_metadata.max_seqlen,
|
||||
attn_metadata.max_seqlen,
|
||||
pdropout=0.0,
|
||||
softmax_scale=self.scale,
|
||||
zero_tensors=False,
|
||||
is_causal=True,
|
||||
return_softmax=False,
|
||||
gen_=None,
|
||||
window_size_left=-1,
|
||||
window_size_right=-1,
|
||||
logits_soft_cap=self.logits_soft_cap,
|
||||
)
|
||||
else:
|
||||
# prefix-enabled attention
|
||||
raise RuntimeError(
|
||||
"IPEX backend doesn't support prefix decoding.")
|
||||
|
||||
else:
|
||||
# Decoding run.
|
||||
max_seq_len = attn_metadata.max_decode_seq_len
|
||||
output = torch.empty_like(query)
|
||||
block_size = value_cache.shape[3]
|
||||
num_seqs, num_heads, head_size = query.shape
|
||||
max_num_partitions = ((max_seq_len + _PARTITION_SIZE - 1) //
|
||||
_PARTITION_SIZE)
|
||||
# NOTE(woosuk): We use a simple heuristic to decide whether to use
|
||||
# PagedAttention V1 or V2. If the number of partitions is 1, we use
|
||||
# V1 to avoid the overhead of reduction. Also, if the number of
|
||||
# sequences or heads is large, we use V1 since there is enough work
|
||||
# to parallelize.
|
||||
# TODO(woosuk): Tune this heuristic.
|
||||
# For context len > 8192, use V2 kernel to avoid shared memory
|
||||
# shortage.
|
||||
use_v1 = (max_seq_len <= 8192 and
|
||||
(max_num_partitions == 1 or num_seqs * num_heads > 512))
|
||||
if use_v1:
|
||||
# Run PagedAttention V1.
|
||||
ipex_ops.paged_attention_v1(
|
||||
output,
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
self.num_kv_heads,
|
||||
self.scale,
|
||||
attn_metadata.block_tables,
|
||||
attn_metadata.seq_lens_tensor,
|
||||
block_size,
|
||||
max_seq_len,
|
||||
self.alibi_slopes,
|
||||
self.kv_cache_dtype,
|
||||
layer._k_scale_float,
|
||||
layer._v_scale_float,
|
||||
)
|
||||
else:
|
||||
# Run PagedAttention V2.
|
||||
assert _PARTITION_SIZE % block_size == 0
|
||||
tmp_output = torch.empty(
|
||||
size=(num_seqs, num_heads, max_num_partitions, head_size),
|
||||
dtype=output.dtype,
|
||||
device=output.device,
|
||||
)
|
||||
exp_sums = torch.empty(
|
||||
size=(num_seqs, num_heads, max_num_partitions),
|
||||
dtype=torch.float32,
|
||||
device=output.device,
|
||||
)
|
||||
max_logits = torch.empty_like(exp_sums)
|
||||
ipex_ops.paged_attention_v2(
|
||||
output,
|
||||
exp_sums,
|
||||
max_logits,
|
||||
tmp_output,
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
self.num_kv_heads,
|
||||
self.scale,
|
||||
attn_metadata.block_tables,
|
||||
attn_metadata.seq_lens_tensor,
|
||||
block_size,
|
||||
max_seq_len,
|
||||
self.alibi_slopes,
|
||||
self.kv_cache_dtype,
|
||||
layer._k_scale_float,
|
||||
layer._v_scale_float,
|
||||
)
|
||||
|
||||
# Reshape the output tensor.
|
||||
return output.view(-1, self.num_heads * self.head_size)
|
||||
|
||||
|
||||
def _make_alibi_bias(
|
||||
alibi_slopes: torch.Tensor,
|
||||
dtype: torch.dtype,
|
||||
seq_lens: List[int],
|
||||
) -> List[torch.Tensor]:
|
||||
attn_biases = []
|
||||
for seq_len in seq_lens:
|
||||
bias = torch.arange(seq_len, dtype=dtype, device=alibi_slopes.device)
|
||||
# NOTE(zhuohan): HF uses
|
||||
# `bias = bias[None, :].repeat(seq_len, 1)`
|
||||
# here. We find that both biases give the same results, but
|
||||
# the bias below more accurately follows the original ALiBi
|
||||
# paper.
|
||||
bias = bias[None, :] - bias[:, None]
|
||||
|
||||
num_heads = alibi_slopes.shape[0]
|
||||
bias = bias[None, :].repeat((num_heads, 1, 1))
|
||||
bias.mul_(alibi_slopes[:, None, None])
|
||||
inf_mask = torch.empty(
|
||||
(1, seq_len, seq_len),
|
||||
dtype=bias.dtype,
|
||||
device=alibi_slopes.device).fill_(-torch.inf).triu_(diagonal=1)
|
||||
attn_biases.append((bias + inf_mask).to(dtype))
|
||||
|
||||
return attn_biases
|
||||
|
||||
|
||||
def _make_sliding_window_bias(
|
||||
seq_lens: List[int],
|
||||
window_size: Optional[int],
|
||||
dtype: torch.dtype,
|
||||
) -> List[torch.Tensor]:
|
||||
attn_biases = []
|
||||
for seq_len in seq_lens:
|
||||
tensor = torch.full(
|
||||
(1, seq_len, seq_len),
|
||||
dtype=dtype,
|
||||
fill_value=1,
|
||||
)
|
||||
shift = 0
|
||||
mask = torch.tril(tensor, diagonal=shift).to(dtype) # type: ignore
|
||||
if window_size is not None:
|
||||
mask = torch.triu(mask, diagonal=shift - window_size + 1)
|
||||
mask = torch.log(mask)
|
||||
attn_biases.append(mask.to(dtype))
|
||||
|
||||
return attn_biases
|
|
@ -0,0 +1,356 @@
|
|||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type
|
||||
|
||||
import torch
|
||||
import torch_xla.experimental.custom_kernel # Required to register custom ops.
|
||||
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionLayer,
|
||||
AttentionMetadata, AttentionType,
|
||||
is_quantized_kv_cache)
|
||||
from vllm.attention.backends.utils import CommonAttentionState
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class PallasAttentionBackend(AttentionBackend):
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "PALLAS"
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> Type["PallasAttentionBackendImpl"]:
|
||||
return PallasAttentionBackendImpl
|
||||
|
||||
@staticmethod
|
||||
def get_metadata_cls() -> Type["PallasMetadata"]:
|
||||
return PallasMetadata
|
||||
|
||||
@staticmethod
|
||||
def get_state_cls() -> Type["CommonAttentionState"]:
|
||||
return CommonAttentionState
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_shape(
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
) -> Tuple[int, ...]:
|
||||
return (num_kv_heads, num_blocks, block_size, head_size)
|
||||
|
||||
@staticmethod
|
||||
def swap_blocks(
|
||||
src_kv_cache: torch.Tensor,
|
||||
dst_kv_cache: torch.Tensor,
|
||||
src_to_dst: torch.Tensor,
|
||||
) -> None:
|
||||
raise RuntimeError("swap_blocks is not used for the TPU backend.")
|
||||
|
||||
@torch.compile(backend="openxla")
|
||||
@staticmethod
|
||||
def copy_blocks(
|
||||
kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||
src_to_dists: Tuple[torch.Tensor, torch.Tensor],
|
||||
) -> None:
|
||||
src_indices, dst_indices = src_to_dists
|
||||
for k_cache, v_cache in kv_caches:
|
||||
torch.ops.xla.dynamo_set_buffer_donor_(k_cache, True)
|
||||
k_cache[:, dst_indices] = k_cache[:, src_indices]
|
||||
torch.ops.xla.dynamo_set_buffer_donor_(v_cache, True)
|
||||
v_cache[:, dst_indices] = v_cache[:, src_indices]
|
||||
|
||||
|
||||
@dataclass
|
||||
class PallasMetadata(AttentionMetadata):
|
||||
|
||||
# Currently, input sequences can only contain all prefills
|
||||
# or all decoding.
|
||||
block_tables: Optional[torch.Tensor] = None
|
||||
context_lens: Optional[torch.Tensor] = None
|
||||
effective_query_lens: Optional[torch.Tensor] = None
|
||||
|
||||
@property
|
||||
def prefill_metadata(self) -> Optional["PallasMetadata"]:
|
||||
if self.num_prefills == 0:
|
||||
return None
|
||||
|
||||
assert self.num_decode_tokens == 0
|
||||
return self
|
||||
|
||||
@property
|
||||
def decode_metadata(self) -> Optional["PallasMetadata"]:
|
||||
if self.num_decode_tokens == 0:
|
||||
return None
|
||||
|
||||
assert self.num_prefills == 0
|
||||
assert self.num_prefill_tokens == 0
|
||||
assert self.block_tables is not None
|
||||
assert self.context_lens is not None
|
||||
return self
|
||||
|
||||
|
||||
class PallasAttentionBackendImpl(AttentionImpl):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: int,
|
||||
alibi_slopes: Optional[List[float]],
|
||||
sliding_window: Optional[int],
|
||||
kv_cache_dtype: str,
|
||||
blocksparse_params: Optional[Dict[str, Any]] = None,
|
||||
logits_soft_cap: Optional[float] = None,
|
||||
attn_type: str = AttentionType.DECODER,
|
||||
kv_sharing_target_layer_name: Optional[str] = None,
|
||||
use_irope: bool = False,
|
||||
) -> None:
|
||||
if kv_sharing_target_layer_name is not None:
|
||||
raise NotImplementedError("KV sharing is not supported in V0.")
|
||||
if use_irope:
|
||||
logger.warning_once(
|
||||
"Using irope in Pallas is not supported yet, it will fall back "
|
||||
"to global attention for long context.")
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.scale = float(scale)
|
||||
self.num_kv_heads = num_kv_heads
|
||||
|
||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||
self.logits_soft_cap = logits_soft_cap
|
||||
if head_size % 128 != 0:
|
||||
raise NotImplementedError(
|
||||
f"Head size must be a multiple of 128, found {head_size}.")
|
||||
if alibi_slopes is not None:
|
||||
raise NotImplementedError("Alibi slopes is not supported.")
|
||||
if sliding_window is not None:
|
||||
raise NotImplementedError("Sliding window is not supported.")
|
||||
if is_quantized_kv_cache(kv_cache_dtype):
|
||||
raise NotImplementedError("FP8 KV cache dtype is not supported.")
|
||||
if blocksparse_params is not None:
|
||||
raise NotImplementedError("Blocksparse is not supported.")
|
||||
|
||||
if torch_xla.tpu.version() < 4:
|
||||
raise NotImplementedError("TPU version must be 4 or higher.")
|
||||
|
||||
self.megacore_mode = None
|
||||
tpu_env = torch_xla.tpu.get_tpu_env()
|
||||
tpu_type = (tpu_env.get("ACCELERATOR_TYPE", None)
|
||||
or tpu_env.get("TYPE", None)
|
||||
or tpu_env.get("TPU_ACCELERATOR_TYPE", None))
|
||||
assert tpu_type is not None
|
||||
tpu_type = tpu_type.lower()
|
||||
|
||||
if (("lite" not in tpu_type) and ("v6" not in tpu_type)):
|
||||
if self.num_kv_heads % 2 == 0:
|
||||
self.megacore_mode = "kv_head"
|
||||
else:
|
||||
# NOTE(woosuk): If the batch size is not a multiple of 2, the
|
||||
# megacore mode will be None.
|
||||
self.megacore_mode = "batch"
|
||||
|
||||
if attn_type != AttentionType.DECODER:
|
||||
raise NotImplementedError("Encoder self-attention and "
|
||||
"encoder/decoder cross-attention "
|
||||
"are not implemented for "
|
||||
"PallasAttentionBackendImpl")
|
||||
|
||||
def forward(
|
||||
self,
|
||||
layer: AttentionLayer,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
kv_cache: Tuple[torch.Tensor, torch.Tensor],
|
||||
attn_metadata: PallasMetadata,
|
||||
output: Optional[torch.Tensor] = None,
|
||||
output_scale: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass with Pallas attention.
|
||||
|
||||
Args:
|
||||
query: shape = [batch_size, seq_len, num_heads * head_size]
|
||||
key: shape = [batch_size, seq_len, num_kv_heads * head_size]
|
||||
value: shape = [batch_size, seq_len, num_kv_heads * head_size]
|
||||
kv_cache[0] = [num_kv_heads, num_blocks, block_size, head_size]
|
||||
kv_cache[1] = [num_kv_heads, num_blocks, block_size, head_size]
|
||||
NOTE: kv_cache[0] and kv_cache[1] will be an empty tensor
|
||||
with shape [0] for profiling run.
|
||||
attn_metadata: Metadata for attention.
|
||||
Returns:
|
||||
shape = [batch_size, seq_len, num_heads * head_size]
|
||||
"""
|
||||
if output_scale is not None:
|
||||
raise NotImplementedError(
|
||||
"fused output quantization is not yet supported"
|
||||
" for PallasAttentionImpl")
|
||||
|
||||
assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
|
||||
batch_size, seq_len, hidden_size = query.shape
|
||||
query = query.view(batch_size, seq_len, self.num_heads, self.head_size)
|
||||
key = key.view(batch_size, seq_len, self.num_kv_heads, self.head_size)
|
||||
value = value.view(batch_size, seq_len, self.num_kv_heads,
|
||||
self.head_size)
|
||||
|
||||
if kv_cache[0].numel() > 0:
|
||||
slot_mapping = attn_metadata.slot_mapping
|
||||
key_cache, value_cache = kv_cache
|
||||
write_to_kv_cache(key, value, key_cache, value_cache, slot_mapping)
|
||||
|
||||
query = query * self.scale
|
||||
if attn_metadata.num_prefills > 0:
|
||||
if attn_metadata.block_tables is None:
|
||||
# Prefill without paged KV cache.
|
||||
assert seq_len % 16 == 0, (
|
||||
"Pallas FlashAttention kernel requires seq_len to be a "
|
||||
f"multiple of 16 but got {seq_len}")
|
||||
|
||||
# Handle GQA/MQA.
|
||||
if self.num_kv_heads != self.num_heads:
|
||||
key = key.repeat_interleave(self.num_queries_per_kv,
|
||||
dim=-2)
|
||||
key = key.view(batch_size, seq_len, self.num_heads,
|
||||
self.head_size)
|
||||
value = value.repeat_interleave(self.num_queries_per_kv,
|
||||
dim=-2)
|
||||
value = value.view(batch_size, seq_len, self.num_heads,
|
||||
self.head_size)
|
||||
# FlashAttention kernel requires the input shape to be
|
||||
# [batch_size, num_heads, seq_len, d_model]
|
||||
# while the input is [batch_size, seq_len, num_heads, d_model].
|
||||
# Permute the input to match the required format.
|
||||
output = torch.ops.xla.flash_attention(
|
||||
query.permute(0, 2, 1, 3),
|
||||
key.permute(0, 2, 1, 3),
|
||||
value.permute(0, 2, 1, 3),
|
||||
True,
|
||||
)
|
||||
output = output.permute(0, 2, 1, 3)
|
||||
else:
|
||||
# Prefill with paged KV cache.
|
||||
# TODO(woosuk): Tune the below knobs.
|
||||
num_kv_pages_per_compute_block = 16
|
||||
num_queries_per_compute_block = 16
|
||||
assert seq_len % num_queries_per_compute_block == 0
|
||||
output = torch.ops.xla.multi_queries_paged_attention(
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
attn_metadata.context_lens,
|
||||
attn_metadata.block_tables,
|
||||
attn_metadata.effective_query_lens,
|
||||
num_kv_pages_per_compute_block,
|
||||
num_queries_per_compute_block,
|
||||
use_kernel=True,
|
||||
attn_logits_soft_cap=self.logits_soft_cap,
|
||||
)
|
||||
else:
|
||||
# Decoding run.
|
||||
assert kv_cache[0].numel() > 0
|
||||
query = query.squeeze(dim=1)
|
||||
pages_per_compute_block = 16 # TODO(woosuk): Tune this value.
|
||||
|
||||
assert attn_metadata.block_tables is not None
|
||||
assert attn_metadata.context_lens is not None
|
||||
# NOTE(woosuk): The PagedAttention Pallas kernel stores the entire
|
||||
# block table in SMEM. Therefore, if the block table is too large,
|
||||
# the kernel compilation will fail. To avoid this, we split the
|
||||
# batch dimension into smaller chunks and run the kernel multiple
|
||||
# times.
|
||||
MAX_SMEM_USAGE = 512 * 1024
|
||||
size_per_seq = 4 * attn_metadata.block_tables.shape[1]
|
||||
max_num_seq = MAX_SMEM_USAGE // size_per_seq
|
||||
|
||||
if batch_size <= max_num_seq:
|
||||
output = paged_attention(
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
attn_metadata.context_lens,
|
||||
attn_metadata.block_tables,
|
||||
pages_per_compute_block,
|
||||
self.megacore_mode,
|
||||
attn_logits_soft_cap=self.logits_soft_cap,
|
||||
)
|
||||
else:
|
||||
chunk_size = max_num_seq
|
||||
# Make sure the chunk size is a multiple of 2.
|
||||
chunk_size = chunk_size // 2 * 2
|
||||
num_chunks = (batch_size + chunk_size - 1) // chunk_size
|
||||
|
||||
output = torch.empty_like(query)
|
||||
for chunk_idx in range(num_chunks):
|
||||
chunk_start = chunk_idx * chunk_size
|
||||
chunk_end = chunk_start + chunk_size
|
||||
# NOTE(woosuk): We skip this line because it causes Dynamo
|
||||
# compilation error. Instead, we rely on the slice operation
|
||||
# to handle the out-of-bound case.
|
||||
# chunk_end = min(chunk_end, batch_size)
|
||||
chunk_output = paged_attention(
|
||||
query[chunk_start:chunk_end],
|
||||
key_cache,
|
||||
value_cache,
|
||||
attn_metadata.context_lens[chunk_start:chunk_end],
|
||||
attn_metadata.block_tables[chunk_start:chunk_end],
|
||||
pages_per_compute_block,
|
||||
self.megacore_mode,
|
||||
attn_logits_soft_cap=self.logits_soft_cap,
|
||||
)
|
||||
output[chunk_start:chunk_end] = chunk_output
|
||||
|
||||
# Reshape the output tensor.
|
||||
return output.reshape(batch_size, seq_len, hidden_size)
|
||||
|
||||
|
||||
def write_to_kv_cache(
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
slot_mapping: torch.Tensor,
|
||||
) -> None:
|
||||
torch.ops.xla.dynamo_set_buffer_donor_(key_cache, True)
|
||||
torch.ops.xla.dynamo_set_buffer_donor_(value_cache, True)
|
||||
|
||||
key = key.flatten(0, 2)
|
||||
value = value.flatten(0, 2)
|
||||
key_cache = key_cache.flatten(0, 2)
|
||||
value_cache = value_cache.flatten(0, 2)
|
||||
key_cache.index_copy_(0, slot_mapping, key)
|
||||
value_cache.index_copy_(0, slot_mapping, value)
|
||||
|
||||
|
||||
def paged_attention(
|
||||
query: torch.Tensor,
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
context_lens: torch.Tensor,
|
||||
block_tables: torch.Tensor,
|
||||
pages_per_compute_block: int,
|
||||
megacore_mode: Optional[str],
|
||||
*,
|
||||
attn_logits_soft_cap: Optional[float],
|
||||
) -> torch.Tensor:
|
||||
batch_size = query.shape[0]
|
||||
if megacore_mode == "batch" and batch_size % 2 != 0:
|
||||
megacore_mode = None
|
||||
else:
|
||||
megacore_mode = megacore_mode
|
||||
|
||||
return torch.ops.xla.paged_attention(
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
context_lens,
|
||||
block_tables,
|
||||
pages_per_compute_block,
|
||||
megacore_mode=megacore_mode,
|
||||
attn_logits_soft_cap=attn_logits_soft_cap,
|
||||
)
|
|
@ -3,24 +3,78 @@
|
|||
""" Attention layer with torch scaled_dot_product_attention
|
||||
and PagedAttention."""
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type
|
||||
|
||||
import torch
|
||||
from torch.nn.functional import scaled_dot_product_attention
|
||||
|
||||
# yapf conflicts with isort for this block
|
||||
# yapf: disable
|
||||
from vllm.attention.backends.abstract import (AttentionImpl, AttentionLayer,
|
||||
AttentionMetadata, AttentionType,
|
||||
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||
AttentionLayer,
|
||||
AttentionMetadata,
|
||||
AttentionMetadataBuilder,
|
||||
AttentionType,
|
||||
is_quantized_kv_cache)
|
||||
# yapf: enable
|
||||
from vllm.attention.backends.utils import CommonAttentionState
|
||||
from vllm.attention.ops.ipex_attn import PagedAttention, _use_ipex
|
||||
from vllm.attention.ops.paged_attn import PagedAttentionMetadata
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import make_tensor_with_pad
|
||||
from vllm.worker.cpu_model_runner import ModelInputForCPUBuilder
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class TorchSDPABackend(AttentionBackend):
|
||||
|
||||
@staticmethod
|
||||
def get_name() -> str:
|
||||
return "TORCH_SDPA"
|
||||
|
||||
@staticmethod
|
||||
def get_impl_cls() -> Type["TorchSDPABackendImpl"]:
|
||||
return TorchSDPABackendImpl
|
||||
|
||||
@staticmethod
|
||||
def get_metadata_cls() -> Type["AttentionMetadata"]:
|
||||
return TorchSDPAMetadata
|
||||
|
||||
@staticmethod
|
||||
def get_state_cls() -> Type["CommonAttentionState"]:
|
||||
return CommonAttentionState
|
||||
|
||||
@staticmethod
|
||||
def get_builder_cls() -> Type["TorchSDPAMetadataBuilder"]:
|
||||
return TorchSDPAMetadataBuilder
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_shape(
|
||||
num_blocks: int,
|
||||
block_size: int,
|
||||
num_kv_heads: int,
|
||||
head_size: int,
|
||||
) -> Tuple[int, ...]:
|
||||
return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
|
||||
num_kv_heads, head_size)
|
||||
|
||||
@staticmethod
|
||||
def swap_blocks(
|
||||
src_kv_cache: torch.Tensor,
|
||||
dst_kv_cache: torch.Tensor,
|
||||
src_to_dst: torch.Tensor,
|
||||
) -> None:
|
||||
raise NotImplementedError("Swap is not supported in TorchSDPABackend.")
|
||||
|
||||
@staticmethod
|
||||
def copy_blocks(
|
||||
kv_caches: List[torch.Tensor],
|
||||
src_to_dists: torch.Tensor,
|
||||
) -> None:
|
||||
PagedAttention.copy_blocks(kv_caches, src_to_dists)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
|
||||
"""Metadata for TorchSDPABackend.
|
||||
|
@ -233,6 +287,113 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
|
|||
raise AttributeError(f"Invalid attention type {str(attn_type)}")
|
||||
|
||||
|
||||
class TorchSDPAMetadataBuilder(AttentionMetadataBuilder[TorchSDPAMetadata]):
|
||||
|
||||
def __init__(self, input_builder: ModelInputForCPUBuilder) -> None:
|
||||
self.chunked_prefill = input_builder.chunked_prefill
|
||||
self.input_builder = input_builder
|
||||
|
||||
def prepare(self):
|
||||
self.input_data = self.input_builder.input_data
|
||||
|
||||
def build(self, seq_lens: List[int], query_lens: List[int],
|
||||
cuda_graph_pad_size: int, batch_size: int) -> TorchSDPAMetadata:
|
||||
input_data = self.input_data
|
||||
prefill_seq_lens = seq_lens[0:input_data.num_prefills]
|
||||
prefill_query_lens = query_lens[0:input_data.num_prefills]
|
||||
slot_mapping = torch.tensor(input_data.slot_mapping,
|
||||
dtype=torch.long,
|
||||
device="cpu")
|
||||
|
||||
# For chunked-prefill
|
||||
if self.chunked_prefill and input_data.num_prefill_tokens != 0:
|
||||
prefill_block_tables = make_tensor_with_pad(
|
||||
self.input_data.prefill_block_tables,
|
||||
pad=0,
|
||||
dtype=torch.int32,
|
||||
device="cpu",
|
||||
)
|
||||
query_lens_tensor = torch.tensor(prefill_query_lens,
|
||||
dtype=torch.int32,
|
||||
device="cpu")
|
||||
kv_lens_tensor = torch.tensor(prefill_seq_lens,
|
||||
dtype=torch.int32,
|
||||
device="cpu")
|
||||
query_start_loc = torch.zeros(input_data.num_prefills + 1,
|
||||
dtype=torch.int32,
|
||||
device="cpu")
|
||||
kv_start_loc = torch.zeros(input_data.num_prefills + 1,
|
||||
dtype=torch.int32,
|
||||
device="cpu")
|
||||
torch.cumsum(query_lens_tensor,
|
||||
dim=0,
|
||||
dtype=torch.int32,
|
||||
out=query_start_loc[1:])
|
||||
torch.cumsum(kv_lens_tensor,
|
||||
dim=0,
|
||||
dtype=torch.int32,
|
||||
out=kv_start_loc[1:])
|
||||
max_query_len = max(prefill_query_lens)
|
||||
max_kv_len = max(prefill_seq_lens)
|
||||
else:
|
||||
prefill_block_tables = None
|
||||
query_start_loc = None
|
||||
kv_start_loc = None
|
||||
max_query_len = None
|
||||
max_kv_len = None
|
||||
|
||||
# For paged attention
|
||||
if input_data.num_decode_tokens != 0:
|
||||
seq_lens_tensor = torch.tensor(
|
||||
input_data.seq_lens[input_data.num_prefills:],
|
||||
dtype=torch.int32,
|
||||
device="cpu",
|
||||
)
|
||||
block_tables = make_tensor_with_pad(
|
||||
self.input_data.decode_block_tables,
|
||||
pad=0,
|
||||
dtype=torch.int32,
|
||||
device="cpu",
|
||||
)
|
||||
else:
|
||||
block_tables = torch.tensor([])
|
||||
seq_lens_tensor = torch.tensor(
|
||||
input_data.seq_lens[:input_data.num_prefills],
|
||||
dtype=torch.int32,
|
||||
device="cpu",
|
||||
)
|
||||
|
||||
# For multi-modal models
|
||||
placeholder_index_maps = None
|
||||
if len(input_data.multi_modal_inputs_list) != 0:
|
||||
placeholder_index_maps = {
|
||||
modality: placeholder_map.index_map()
|
||||
for modality, placeholder_map in
|
||||
input_data.multi_modal_placeholder_maps.items()
|
||||
}
|
||||
|
||||
attn_metadata = TorchSDPAMetadata(
|
||||
chunked_prefill=self.chunked_prefill,
|
||||
seq_lens=prefill_seq_lens,
|
||||
seq_lens_tensor=seq_lens_tensor,
|
||||
max_query_len=max_query_len,
|
||||
max_kv_len=max_kv_len,
|
||||
prefill_query_start_loc=query_start_loc,
|
||||
kv_start_loc=kv_start_loc,
|
||||
max_decode_seq_len=input_data.max_decode_seq_len,
|
||||
num_prefills=input_data.num_prefills,
|
||||
num_prefill_tokens=input_data.num_prefill_tokens,
|
||||
num_decode_tokens=input_data.num_decode_tokens,
|
||||
block_tables=block_tables,
|
||||
prefill_block_tables=prefill_block_tables,
|
||||
slot_mapping=slot_mapping,
|
||||
multi_modal_placeholder_index_maps=placeholder_index_maps,
|
||||
enable_kv_scales_calculation=False,
|
||||
)
|
||||
|
||||
return attn_metadata
|
||||
|
||||
|
||||
class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
|
||||
|
||||
def __init__(
|
||||
|
|
|
@ -64,11 +64,13 @@ class CpuPlatform(Platform):
|
|||
if selected_backend and selected_backend != _Backend.TORCH_SDPA:
|
||||
logger.info("Cannot use %s backend on CPU.", selected_backend)
|
||||
if use_mla:
|
||||
raise NotImplementedError("MLA is not supported on CPU.")
|
||||
logger.info("Using CPU MLA backend.")
|
||||
return "vllm.attention.backends.cpu_mla.CPUMLABackend"
|
||||
logger.info("Using Torch SDPA backend.")
|
||||
if not use_v1:
|
||||
raise ValueError("CPU backend only supports V1.")
|
||||
return "vllm.v1.attention.backends.cpu_attn.TorchSDPABackend"
|
||||
if use_v1:
|
||||
return "vllm.v1.attention.backends.cpu_attn.TorchSDPABackend"
|
||||
else:
|
||||
return "vllm.attention.backends.torch_sdpa.TorchSDPABackend"
|
||||
|
||||
@classmethod
|
||||
def get_device_total_memory(cls, device_id: int = 0) -> int:
|
||||
|
@ -145,14 +147,26 @@ class CpuPlatform(Platform):
|
|||
parallel_config.distributed_executor_backend)
|
||||
parallel_config.distributed_executor_backend = "mp"
|
||||
if parallel_config.worker_cls == "auto":
|
||||
parallel_config.worker_cls = "vllm.v1.worker.cpu_worker.CPUWorker"
|
||||
if vllm_config.speculative_config:
|
||||
parallel_config.worker_cls = \
|
||||
"vllm.spec_decode.spec_decode_worker.create_spec_worker"
|
||||
parallel_config.sd_worker_cls = \
|
||||
"vllm.worker.cpu_worker.CPUWorker"
|
||||
else:
|
||||
if envs.VLLM_USE_V1:
|
||||
parallel_config.worker_cls = \
|
||||
"vllm.v1.worker.cpu_worker.CPUWorker"
|
||||
else:
|
||||
parallel_config.worker_cls = \
|
||||
"vllm.worker.cpu_worker.CPUWorker"
|
||||
|
||||
# Note: workaround for v1 gpu_model_runner
|
||||
from vllm.config import CompilationLevel
|
||||
vllm_config.compilation_config.cudagraph_capture_sizes = []
|
||||
|
||||
compilation_config = vllm_config.compilation_config
|
||||
if vllm_config.compilation_config.level == CompilationLevel.PIECEWISE:
|
||||
if (envs.VLLM_USE_V1 and vllm_config.compilation_config.level
|
||||
== CompilationLevel.PIECEWISE):
|
||||
|
||||
# Note: vLLM V1 is using PIECEWISE level compilation, which will
|
||||
# take time to compile kernels just-in-time with the inductor
|
||||
|
|
|
@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Optional, Union, cast
|
|||
import torch
|
||||
from tpu_info import device
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.inputs import ProcessorInputs, PromptType
|
||||
from vllm.logger import init_logger
|
||||
from vllm.sampling_params import SamplingParams, SamplingType
|
||||
|
@ -49,10 +50,12 @@ class TpuPlatform(Platform):
|
|||
and selected_backend != _Backend.PALLAS_VLLM_V1):
|
||||
logger.info("Cannot use %s backend on TPU.", selected_backend)
|
||||
|
||||
if not use_v1:
|
||||
raise ValueError("TPU backend only supports V1.")
|
||||
logger.info("Using Pallas V1 backend.")
|
||||
return "vllm.v1.attention.backends.pallas.PallasAttentionBackend"
|
||||
if use_v1:
|
||||
logger.info("Using Pallas V1 backend.")
|
||||
return "vllm.v1.attention.backends.pallas.PallasAttentionBackend"
|
||||
else:
|
||||
logger.info("Using Pallas backend.")
|
||||
return "vllm.attention.backends.pallas.PallasAttentionBackend"
|
||||
|
||||
@classmethod
|
||||
def get_device_name(cls, device_id: int = 0) -> str:
|
||||
|
@ -65,7 +68,7 @@ class TpuPlatform(Platform):
|
|||
|
||||
@classmethod
|
||||
def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
|
||||
return False
|
||||
return not envs.VLLM_USE_V1
|
||||
|
||||
@classmethod
|
||||
def get_punica_wrapper(cls) -> str:
|
||||
|
@ -114,19 +117,31 @@ class TpuPlatform(Platform):
|
|||
"Using bfloat16 instead.", vllm_config.model_config.dtype)
|
||||
vllm_config.model_config.dtype = torch.bfloat16
|
||||
|
||||
from vllm.v1.attention.backends.pallas import PallasAttentionBackend
|
||||
cache_config.block_size = PallasAttentionBackend.get_page_size(
|
||||
vllm_config) # type: ignore[assignment]
|
||||
if envs.VLLM_USE_V1:
|
||||
from vllm.v1.attention.backends.pallas import (
|
||||
PallasAttentionBackend)
|
||||
cache_config.block_size = PallasAttentionBackend.get_page_size(
|
||||
vllm_config) # type: ignore[assignment]
|
||||
|
||||
parallel_config = vllm_config.parallel_config
|
||||
scheduler_config = vllm_config.scheduler_config
|
||||
if parallel_config.worker_cls == "auto":
|
||||
if scheduler_config.is_multi_step:
|
||||
raise NotImplementedError(
|
||||
"Multi-step scheduling is not supported (and not "
|
||||
"needed) on vLLM V1. Please launch without "
|
||||
"--num-scheduler-steps.")
|
||||
parallel_config.worker_cls = "vllm.v1.worker.tpu_worker.TPUWorker"
|
||||
if envs.VLLM_USE_V1:
|
||||
raise NotImplementedError(
|
||||
"Multi-step scheduling is not supported (and not "
|
||||
"needed) on vLLM V1. Please launch without "
|
||||
"--num-scheduler-steps.")
|
||||
else:
|
||||
parallel_config.worker_cls = \
|
||||
"vllm.worker.multi_step_tpu_worker.MultiStepTPUWorker"
|
||||
else:
|
||||
if envs.VLLM_USE_V1:
|
||||
parallel_config.worker_cls = \
|
||||
"vllm.v1.worker.tpu_worker.TPUWorker"
|
||||
else:
|
||||
parallel_config.worker_cls = \
|
||||
"vllm.worker.tpu_worker.TPUWorker"
|
||||
|
||||
assert not vllm_config.speculative_config, (
|
||||
"Speculative decoding is not yet supported for TPU backend")
|
||||
|
@ -174,9 +189,13 @@ class TpuPlatform(Platform):
|
|||
processed_inputs: ProcessorInputs,
|
||||
) -> None:
|
||||
"""Raises if this request is unsupported on this platform"""
|
||||
if (isinstance(params, SamplingParams)
|
||||
and params.sampling_type == SamplingType.RANDOM_SEED):
|
||||
raise ValueError("Torch XLA does not support per-request seed.")
|
||||
if isinstance(params, SamplingParams):
|
||||
if params.guided_decoding is not None and not envs.VLLM_USE_V1:
|
||||
raise ValueError("Structured output is not supported on "
|
||||
f"{cls.device_name} V0.")
|
||||
if params.sampling_type == SamplingType.RANDOM_SEED:
|
||||
raise ValueError(
|
||||
"Torch XLA does not support per-request seed.")
|
||||
|
||||
|
||||
try:
|
||||
|
|
|
@ -39,10 +39,12 @@ class XPUPlatform(Platform):
|
|||
if selected_backend != _Backend.IPEX:
|
||||
logger.info("Cannot use %s backend on XPU.", selected_backend)
|
||||
use_v1 = envs.VLLM_USE_V1
|
||||
if not use_v1:
|
||||
raise ValueError("XPU backend only supports V1.")
|
||||
logger.info("Using Flash Attention backend on V1 engine.")
|
||||
return "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
|
||||
if use_v1:
|
||||
logger.info("Using Flash Attention backend on V1 engine.")
|
||||
return "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
|
||||
else:
|
||||
logger.info("Using IPEX attention backend.")
|
||||
return "vllm.attention.backends.ipex_attn.IpexAttnBackend"
|
||||
|
||||
@classmethod
|
||||
def get_device_capability(
|
||||
|
@ -75,7 +77,10 @@ class XPUPlatform(Platform):
|
|||
cache_config = vllm_config.cache_config
|
||||
# in V1(or with ipex chunked prefill) block_size is 64
|
||||
if cache_config and cache_config.block_size is None:
|
||||
cache_config.block_size = 64
|
||||
if envs.VLLM_USE_V1:
|
||||
cache_config.block_size = 64
|
||||
else:
|
||||
cache_config.block_size = 16
|
||||
|
||||
# Instances created using VllmConfig() typically have model_config as
|
||||
# None by default. The modification involves adding a check to prevent
|
||||
|
@ -101,7 +106,11 @@ class XPUPlatform(Platform):
|
|||
|
||||
# check and update parallel config
|
||||
parallel_config = vllm_config.parallel_config
|
||||
parallel_config.worker_cls = "vllm.v1.worker.xpu_worker.XPUWorker"
|
||||
if envs.VLLM_USE_V1:
|
||||
parallel_config.worker_cls =\
|
||||
"vllm.v1.worker.xpu_worker.XPUWorker"
|
||||
else:
|
||||
parallel_config.worker_cls = "vllm.worker.xpu_worker.XPUWorker"
|
||||
|
||||
if parallel_config.distributed_executor_backend is None:
|
||||
if parallel_config.world_size > 1:
|
||||
|
|
|
@ -0,0 +1,326 @@
|
|||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import dataclasses
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, cast
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.model_executor import SamplingMetadata
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.multimodal import MultiModalKwargs
|
||||
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
|
||||
from vllm.utils import make_tensor_with_pad
|
||||
from vllm.worker.cpu_model_runner import (CPUModelRunnerBase,
|
||||
ModelInputForCPUBuilder,
|
||||
ModelInputForCPUWithSamplingMetadata)
|
||||
from vllm.worker.model_runner_base import (
|
||||
_add_attn_metadata_broadcastable_dict,
|
||||
_add_sampling_metadata_broadcastable_dict)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class EncoderDecoderModelInputForCPU(ModelInputForCPUWithSamplingMetadata):
|
||||
"""
|
||||
Used by the EncoderDecoderModelRunner.
|
||||
"""
|
||||
encoder_input_tokens: Optional[torch.Tensor] = None
|
||||
encoder_input_positions: Optional[torch.Tensor] = None
|
||||
|
||||
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
|
||||
tensor_dict = {
|
||||
"input_tokens": self.input_tokens,
|
||||
"input_positions": self.input_positions,
|
||||
"encoder_input_tokens": self.encoder_input_tokens,
|
||||
"encoder_input_positions": self.encoder_input_positions,
|
||||
"multi_modal_kwargs": self.multi_modal_kwargs,
|
||||
}
|
||||
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
|
||||
_add_sampling_metadata_broadcastable_dict(tensor_dict,
|
||||
self.sampling_metadata)
|
||||
return tensor_dict
|
||||
|
||||
@classmethod
|
||||
def from_broadcasted_tensor_dict(
|
||||
cls,
|
||||
tensor_dict: Dict[str, Any],
|
||||
attn_backend: Optional["AttentionBackend"] = None,
|
||||
) -> "EncoderDecoderModelInputForCPU":
|
||||
return cast(
|
||||
EncoderDecoderModelInputForCPU,
|
||||
super().from_broadcasted_tensor_dict(tensor_dict, attn_backend))
|
||||
|
||||
|
||||
class CPUEncoderDecoderModelRunner(
|
||||
CPUModelRunnerBase[EncoderDecoderModelInputForCPU]):
|
||||
_model_input_cls: Type[EncoderDecoderModelInputForCPU] = (
|
||||
EncoderDecoderModelInputForCPU)
|
||||
_builder_cls: Type[ModelInputForCPUBuilder] = ModelInputForCPUBuilder
|
||||
|
||||
def _list_to_int32_tensor(
|
||||
self,
|
||||
_list: List[int],
|
||||
) -> torch.Tensor:
|
||||
return torch.tensor(_list, dtype=torch.int32, device=self.device)
|
||||
|
||||
def _list_to_long_tensor(
|
||||
self,
|
||||
_list: List[int],
|
||||
) -> torch.Tensor:
|
||||
return torch.tensor(_list, dtype=torch.long, device=self.device)
|
||||
|
||||
def _empty_int32_tensor(self) -> torch.Tensor:
|
||||
return self._list_to_int32_tensor([])
|
||||
|
||||
def _empty_long_tensor(self) -> torch.Tensor:
|
||||
return self._list_to_long_tensor([])
|
||||
|
||||
def make_model_input_from_broadcasted_tensor_dict(
|
||||
self, tensor_dict: Dict[str,
|
||||
Any]) -> EncoderDecoderModelInputForCPU:
|
||||
return EncoderDecoderModelInputForCPU.from_broadcasted_tensor_dict(
|
||||
tensor_dict,
|
||||
attn_backend=self.attn_backend,
|
||||
)
|
||||
|
||||
def prepare_model_input(
|
||||
self,
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
virtual_engine: int = 0,
|
||||
finished_requests_ids: Optional[List[str]] = None
|
||||
) -> EncoderDecoderModelInputForCPU:
|
||||
model_input = self._prepare_model_input_tensors(
|
||||
seq_group_metadata_list, finished_requests_ids)
|
||||
(
|
||||
attn_metadata,
|
||||
encoder_input_tokens_tensor,
|
||||
encoder_input_positions_tensor,
|
||||
) = self._prepare_encoder_model_input_tensors(seq_group_metadata_list,
|
||||
model_input)
|
||||
# Sampling metadata is only required for the final pp group
|
||||
generators = self.get_generators(finished_requests_ids)
|
||||
sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list,
|
||||
model_input.seq_lens,
|
||||
model_input.query_lens,
|
||||
self.device,
|
||||
pin_memory=False,
|
||||
generators=generators)
|
||||
return dataclasses.replace(
|
||||
model_input,
|
||||
sampling_metadata=sampling_metadata,
|
||||
attn_metadata=attn_metadata,
|
||||
encoder_input_tokens=encoder_input_tokens_tensor,
|
||||
encoder_input_positions=encoder_input_positions_tensor,
|
||||
virtual_engine=virtual_engine,
|
||||
)
|
||||
|
||||
def _prepare_encoder_model_input_tensors(
|
||||
self,
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
model_input: EncoderDecoderModelInputForCPU,
|
||||
) -> Tuple[AttentionMetadata, Optional[torch.Tensor],
|
||||
Optional[torch.Tensor]]:
|
||||
"""Helper method to prepare the encoder- and cross-attn-related
|
||||
model inputs based on a given sequence group. These additional inputs
|
||||
are used to augment an already-computed `EncoderDecoderModelInput`
|
||||
data structure which already has decoder-related model inputs
|
||||
populated.
|
||||
|
||||
Sets the following attn_metadata fields:
|
||||
* `num_encoder_tokens`
|
||||
* `encoder_seq_lens`
|
||||
* `encoder_seq_lens_tensor`
|
||||
* `max_encoder_seq_len`
|
||||
* `cross_slot_mapping`
|
||||
* `cross_block_tables`
|
||||
|
||||
Constructs a new model inputs data structure, based on
|
||||
(1) the existing fields in the `model_inputs` argument,
|
||||
and (2) the following additional fields which are
|
||||
computed (or in the case of `attn_metadata`, updated)
|
||||
by this function:
|
||||
* attn_metadata
|
||||
* encoder_input_tokens
|
||||
* encoder_input_positions
|
||||
|
||||
Arguments:
|
||||
|
||||
* seq_group_metadata_list: list of sequence groups for which to
|
||||
compute inputs
|
||||
* model_inputs: model inputs data structure with decoder-oriented
|
||||
fields already computed.
|
||||
|
||||
Return:
|
||||
|
||||
* Updated model inputs data structure
|
||||
"""
|
||||
|
||||
if len(seq_group_metadata_list) == 0:
|
||||
return (model_input.attn_metadata, None, None)
|
||||
|
||||
        # Since chunked prefill is not supported here, the entire batch is
        # either prefill or decode.
        is_prompt = seq_group_metadata_list[0].is_prompt
|
||||
|
||||
# Build encoder inputs
|
||||
encoder_seq_lens: List[int] = []
|
||||
if is_prompt:
|
||||
# Prefill phase.
|
||||
cross_block_tables = self._empty_int32_tensor().view(
|
||||
len(seq_group_metadata_list), -1)
|
||||
|
||||
# Extract input tokens/positions, cross-attention slot-mapping,
|
||||
# & seq len from each sequence group metadata
|
||||
(
|
||||
encoder_input_tokens,
|
||||
encoder_input_positions,
|
||||
cross_slot_mapping,
|
||||
) = (
|
||||
[],
|
||||
[],
|
||||
[],
|
||||
)
|
||||
for seq_group_metadata in seq_group_metadata_list:
|
||||
# Build seq lens
|
||||
seq_len = seq_group_metadata.encoder_seq_data.get_len()
|
||||
token_ids = seq_group_metadata.encoder_seq_data.get_token_ids()
|
||||
encoder_seq_lens.append(seq_len)
|
||||
|
||||
# Build slot mapping
|
||||
for i in range(0, seq_len):
|
||||
block_number = seq_group_metadata.cross_block_table[
|
||||
i // self.block_size]
|
||||
block_offset = i % self.block_size
|
||||
slot = block_number * self.block_size + block_offset
|
||||
cross_slot_mapping.append(slot)
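                    # Illustrative example (not part of the original source):
                    # with block_size=16 and cross_block_table=[7, 3], token
                    # index i=18 falls in cross_block_table[1] == 3 at offset
                    # 2, so the slot is 3 * 16 + 2 == 50.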
|
||||
|
||||
# Build encoder input tokens
|
||||
encoder_input_tokens.extend(token_ids)
|
||||
encoder_input_positions.extend(list(range(0, seq_len)))
|
||||
|
||||
# Convert tokens/positions & cross-attention
|
||||
# slot-mapping to encoder input tensors
|
||||
encoder_input_tokens_tensor = self._list_to_long_tensor(
|
||||
encoder_input_tokens)
|
||||
encoder_input_positions_tensor = self._list_to_long_tensor(
|
||||
encoder_input_positions)
|
||||
cross_slot_mapping_tensor = self._list_to_long_tensor(
|
||||
cross_slot_mapping)
|
||||
|
||||
else:
|
||||
# Decode phase.
|
||||
encoder_input_tokens_tensor = self._empty_long_tensor()
|
||||
encoder_input_positions_tensor = self._empty_long_tensor()
|
||||
cross_slot_mapping_tensor = self._empty_long_tensor()
|
||||
# Extract cross-attention block tables &
|
||||
# seq len from each sequence group metadata.
|
||||
# Cross-attention block tables are empty
|
||||
# during vLLM memory profiling.
|
||||
cross_block_tables = []
|
||||
for seq_group_metadata in seq_group_metadata_list:
|
||||
for _ in range(len(seq_group_metadata.seq_data)):
|
||||
encoder_seq_lens.append(
|
||||
seq_group_metadata.encoder_seq_data.get_len())
|
||||
cross_block_table = seq_group_metadata.cross_block_table
|
||||
cross_block_tables.append([] if (
|
||||
cross_block_table is None) else cross_block_table)
|
||||
|
||||
max_len_of_block_table = max(
|
||||
len(block_table) for block_table in cross_block_tables)
|
||||
|
||||
cross_block_tables = make_tensor_with_pad(
|
||||
cross_block_tables,
|
||||
max_len=max_len_of_block_table,
|
||||
pad=0,
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
# Compute encoder sequence lengths & encoder
|
||||
# sequence starting offset tensors
|
||||
max_encoder_seq_len = max(encoder_seq_lens, default=0)
|
||||
encoder_seq_lens_tensor = self._list_to_int32_tensor(encoder_seq_lens)
|
||||
encoder_seq_start_loc = torch.zeros(encoder_seq_lens_tensor.shape[0] +
|
||||
1,
|
||||
dtype=torch.int32,
|
||||
device=self.device)
|
||||
torch.cumsum(encoder_seq_lens_tensor,
|
||||
dim=0,
|
||||
dtype=encoder_seq_start_loc.dtype,
|
||||
out=encoder_seq_start_loc[1:])
|
||||
|
||||
# Update attention metadata with encoder-oriented attributes
|
||||
attn_metadata = model_input.attn_metadata
|
||||
assert attn_metadata is not None
|
||||
(
|
||||
attn_metadata.num_encoder_tokens,
|
||||
attn_metadata.encoder_seq_lens,
|
||||
attn_metadata.encoder_seq_lens_tensor,
|
||||
attn_metadata.max_encoder_seq_len,
|
||||
attn_metadata.cross_slot_mapping,
|
||||
attn_metadata.cross_block_tables,
|
||||
) = (
|
||||
sum(encoder_seq_lens),
|
||||
encoder_seq_lens,
|
||||
encoder_seq_lens_tensor,
|
||||
max_encoder_seq_len,
|
||||
cross_slot_mapping_tensor,
|
||||
cross_block_tables,
|
||||
)
|
||||
|
||||
return (attn_metadata, encoder_input_tokens_tensor,
|
||||
encoder_input_positions_tensor)
|
||||
|
||||
@torch.no_grad()
|
||||
def execute_model(
|
||||
self,
|
||||
model_input: EncoderDecoderModelInputForCPU,
|
||||
kv_caches: List[torch.Tensor],
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
num_steps: int = 1,
|
||||
) -> Optional[List[SamplerOutput]]:
|
||||
if num_steps > 1:
|
||||
raise ValueError(
|
||||
"CPU worker does not support multi-step execution.")
|
||||
|
||||
model_executable = self.model
|
||||
execute_model_kwargs = {
|
||||
"input_ids":
|
||||
model_input.input_tokens,
|
||||
"positions":
|
||||
model_input.input_positions,
|
||||
"encoder_input_ids":
|
||||
model_input.encoder_input_tokens,
|
||||
"encoder_positions":
|
||||
model_input.encoder_input_positions,
|
||||
**MultiModalKwargs.as_kwargs(
|
||||
model_input.multi_modal_kwargs or {},
|
||||
device=self.device,
|
||||
),
|
||||
"intermediate_tensors":
|
||||
intermediate_tensors,
|
||||
}
|
||||
|
||||
with set_forward_context(model_input.attn_metadata, self.vllm_config,
|
||||
model_input.virtual_engine):
|
||||
hidden_states = model_executable(**execute_model_kwargs)
|
||||
|
||||
# Compute the logits.
|
||||
logits = self.model.compute_logits(hidden_states,
|
||||
model_input.sampling_metadata)
|
||||
|
||||
# Only perform sampling in the driver worker.
|
||||
if not self.is_driver_worker:
|
||||
return []
|
||||
|
||||
# Sample the next token.
|
||||
output = self.sampler(
|
||||
logits=logits,
|
||||
sampling_metadata=model_input.sampling_metadata,
|
||||
)
|
||||
return [output]
|
|
@ -0,0 +1,671 @@
|
|||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import dataclasses
|
||||
import weakref
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Set, Type,
|
||||
TypeVar, Union)
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from vllm.attention import AttentionMetadata, get_attn_backend
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.layers import LoRAMapping
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
|
||||
from vllm.model_executor import SamplingMetadata
|
||||
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
||||
from vllm.model_executor.model_loader import get_model
|
||||
from vllm.model_executor.models import supports_lora, supports_multimodal
|
||||
from vllm.multimodal import (BatchedTensorInputs, MultiModalKwargs,
|
||||
MultiModalPlaceholderMap)
|
||||
from vllm.sequence import (IntermediateTensors, SequenceData,
|
||||
SequenceGroupMetadata)
|
||||
from vllm.worker.model_runner_base import (
|
||||
ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
|
||||
_add_attn_metadata_broadcastable_dict,
|
||||
_add_sampling_metadata_broadcastable_dict,
|
||||
_init_attn_metadata_from_tensor_dict,
|
||||
_init_sampling_metadata_from_tensor_dict)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
TModelInputForCPU = TypeVar('TModelInputForCPU', bound="ModelInputForCPU")
|
||||
_PAD_SLOT_ID = -1
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ModelInputForCPU(ModelRunnerInputBase):
|
||||
"""
|
||||
Base class contains metadata needed for the base model forward pass on CPU
|
||||
"""
|
||||
input_tokens: Optional[torch.Tensor] = None
|
||||
input_positions: Optional[torch.Tensor] = None
|
||||
token_type_ids: Optional[torch.Tensor] = None
|
||||
attn_metadata: Optional["AttentionMetadata"] = None
|
||||
multi_modal_kwargs: Optional[BatchedTensorInputs] = None
|
||||
virtual_engine: Optional[int] = None
|
||||
seq_lens: Optional[List[int]] = None
|
||||
query_lens: Optional[List[int]] = None
|
||||
lora_mapping: Optional["LoRAMapping"] = None
|
||||
lora_requests: Optional[Set[LoRARequest]] = None
|
||||
|
||||
def as_broadcastable_tensor_dict(
|
||||
self) -> Dict[str, Union[int, torch.Tensor]]:
|
||||
tensor_dict = {
|
||||
"input_tokens": self.input_tokens,
|
||||
"input_positions": self.input_positions,
|
||||
"token_type_ids": self.token_type_ids,
|
||||
"multi_modal_kwargs": self.multi_modal_kwargs,
|
||||
"lora_requests": self.lora_requests,
|
||||
"lora_mapping": self.lora_mapping,
|
||||
}
|
||||
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
|
||||
|
||||
return tensor_dict
|
||||
|
||||
@classmethod
|
||||
def from_broadcasted_tensor_dict(
|
||||
cls: Type[TModelInputForCPU],
|
||||
tensor_dict: Dict[str, Any],
|
||||
attn_backend: Optional["AttentionBackend"] = None
|
||||
) -> TModelInputForCPU:
|
||||
if attn_backend is not None:
|
||||
tensor_dict = _init_attn_metadata_from_tensor_dict(
|
||||
attn_backend, tensor_dict)
|
||||
return cls(**tensor_dict)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ModelInputForCPUWithSamplingMetadata(ModelInputForCPU):
|
||||
"""
|
||||
Used by the ModelRunner.
|
||||
"""
|
||||
sampling_metadata: Optional["SamplingMetadata"] = None
|
||||
is_prompt: Optional[bool] = None
|
||||
|
||||
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
|
||||
tensor_dict = {
|
||||
"input_tokens": self.input_tokens,
|
||||
"input_positions": self.input_positions,
|
||||
"token_type_ids": self.token_type_ids,
|
||||
"multi_modal_kwargs": self.multi_modal_kwargs,
|
||||
}
|
||||
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
|
||||
_add_sampling_metadata_broadcastable_dict(tensor_dict,
|
||||
self.sampling_metadata)
|
||||
return tensor_dict
|
||||
|
||||
@classmethod
|
||||
def from_broadcasted_tensor_dict(
|
||||
cls,
|
||||
tensor_dict: Dict[str, Any],
|
||||
attn_backend: Optional["AttentionBackend"] = None,
|
||||
) -> "ModelInputForCPUWithSamplingMetadata":
|
||||
tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict)
|
||||
if attn_backend is not None:
|
||||
tensor_dict = _init_attn_metadata_from_tensor_dict(
|
||||
attn_backend, tensor_dict)
|
||||
return cls(**tensor_dict)
|
||||
|
||||
|
||||
class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
|
||||
|
||||
class ModelInputData:
|
||||
|
||||
def __init__(self, use_mrope: bool):
|
||||
self.use_mrope = use_mrope
|
||||
self.input_tokens: List[int] = []
|
||||
self.input_positions: List[int] = []
|
||||
self.token_type_ids: Optional[List[int]] = []
|
||||
self.seq_lens: List[int] = []
|
||||
self.query_lens: List[int] = []
|
||||
self.prefill_block_tables: List[List[int]] = []
|
||||
self.decode_block_tables: List[List[int]] = []
|
||||
self.max_decode_seq_len: int = 0
|
||||
self.num_prefills: int = 0
|
||||
self.num_prefill_tokens: int = 0
|
||||
self.num_decode_tokens: int = 0
|
||||
self.slot_mapping: List[int] = []
|
||||
self.multi_modal_inputs_list: List[MultiModalKwargs] = []
|
||||
self.multi_modal_placeholder_maps: Dict[
|
||||
str, MultiModalPlaceholderMap] = defaultdict(
|
||||
MultiModalPlaceholderMap)
|
||||
self.input_mrope_positions: List[List[int]] = [[]
|
||||
for _ in range(3)]
|
||||
|
||||
def __init__(self,
|
||||
runner: "CPUModelRunner",
|
||||
finished_requests_ids: Optional[List[str]] = None) -> None:
|
||||
super().__init__()
|
||||
self.runner = runner
|
||||
self.chunked_prefill = (runner.scheduler_config.chunked_prefill_enabled
|
||||
or runner.cache_config.enable_prefix_caching)
|
||||
self.model_input_cls = self.runner._model_input_cls
|
||||
self.attn_backend = self.runner.attn_backend
|
||||
self.sliding_window = self.runner.sliding_window
|
||||
self.block_size = self.runner.block_size
|
||||
self.device = self.runner.device
|
||||
self.enable_lora = self.runner.lora_config is not None
|
||||
        if self.runner.attn_backend is not None:
            # Speculative decoding (e.g. Medusa) does not have an
            # attention backend.
            attn_backend = self.runner.attn_backend
            self.att_metadata_builder = attn_backend.get_builder_cls()(self)
|
||||
|
||||
def prepare(self,
|
||||
finished_requests_ids: Optional[List[str]] = None) -> None:
|
||||
self.seq_group_metadata_list: List[SequenceGroupMetadata] = []
|
||||
self.input_data = ModelInputForCPUBuilder.ModelInputData(
|
||||
self.runner.model_config.uses_mrope)
|
||||
self.att_metadata_builder.prepare()
|
||||
|
||||
def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
|
||||
self.seq_group_metadata_list.append(seq_group_metadata)
|
||||
|
||||
def set_seq_group_list(
|
||||
self, seq_group_metadata_list: List[SequenceGroupMetadata]):
|
||||
self.seq_group_metadata_list = seq_group_metadata_list
|
||||
|
||||
def build(self) -> ModelInputForCPU:
|
||||
self._build_input_data()
|
||||
|
||||
input_data = self.input_data
|
||||
input_tokens = torch.tensor(input_data.input_tokens,
|
||||
dtype=torch.long,
|
||||
device="cpu")
|
||||
input_positions = torch.tensor(
|
||||
input_data.input_positions
|
||||
if not any(input_data.input_mrope_positions) else
|
||||
input_data.input_mrope_positions,
|
||||
dtype=torch.long,
|
||||
device="cpu")
|
||||
token_type_ids = torch.tensor(input_data.token_type_ids,
|
||||
dtype=torch.long,
|
||||
device="cpu") \
|
||||
if input_data.token_type_ids else None
|
||||
|
||||
# For multi-modal models
|
||||
multi_modal_kwargs = None
|
||||
if len(input_data.multi_modal_inputs_list) != 0:
|
||||
multi_modal_kwargs = MultiModalKwargs.batch(
|
||||
input_data.multi_modal_inputs_list)
|
||||
|
||||
attn_metadata = self.att_metadata_builder.build(
|
||||
input_data.seq_lens, input_data.query_lens, -1, -1)
|
||||
|
||||
is_prompt = (self.seq_group_metadata_list[0].is_prompt
|
||||
if self.seq_group_metadata_list else None)
|
||||
# LoRA data.
|
||||
lora_requests = set()
|
||||
lora_mapping = None
|
||||
if self.enable_lora:
|
||||
lora_requests = set(seq.lora_request
|
||||
for seq in self.seq_group_metadata_list
|
||||
if seq.lora_request is not None)
|
||||
|
||||
lora_mapping = self._prepare_lora_input(
|
||||
self.seq_group_metadata_list, is_prompt)
|
||||
|
||||
return self.model_input_cls(input_tokens=input_tokens,
|
||||
input_positions=input_positions,
|
||||
token_type_ids=token_type_ids,
|
||||
seq_lens=input_data.seq_lens,
|
||||
query_lens=input_data.query_lens,
|
||||
attn_metadata=attn_metadata,
|
||||
multi_modal_kwargs=multi_modal_kwargs,
|
||||
lora_mapping=lora_mapping,
|
||||
lora_requests=lora_requests)
|
||||
|
||||
def _build_input_data(self):
|
||||
for seq_group_metadata in self.seq_group_metadata_list:
|
||||
for seq_id, seq_data in seq_group_metadata.seq_data.items():
|
||||
if seq_group_metadata.is_prompt:
|
||||
self._compute_prompt_input_tokens(self.input_data,
|
||||
seq_group_metadata,
|
||||
seq_data, seq_id)
|
||||
if seq_group_metadata.multi_modal_data:
|
||||
self._compute_multi_modal_input(
|
||||
seq_group_metadata, seq_data)
|
||||
else:
|
||||
self._compute_decode_input_tokens(self.input_data,
|
||||
seq_group_metadata,
|
||||
seq_data, seq_id)
|
||||
|
||||
def _compute_decode_input_tokens(self, data: ModelInputData,
|
||||
seq_group_metadata: SequenceGroupMetadata,
|
||||
seq_data: SequenceData, seq_id: int):
|
||||
"""
|
||||
Compute decode input tokens, positions, block table and slot mapping.
|
||||
"""
|
||||
block_size = self.runner.block_size
|
||||
|
||||
block_table = seq_group_metadata.block_tables[seq_id]
|
||||
seq_len = seq_data.get_len()
|
||||
context_len = seq_data.get_num_computed_tokens()
|
||||
|
||||
tokens = seq_data.get_last_token_id()
|
||||
token_positions = seq_len - 1
|
||||
block_number = block_table[token_positions // block_size]
|
||||
block_offset = token_positions % block_size
|
||||
slot = block_number * block_size + block_offset
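        # Illustrative example (not part of the original source): for a decode
        # step with seq_len=33 and block_size=16, the last token position is
        # 32, which lives in block_table[2] at offset 0, so the slot is
        # block_table[2] * 16 + 0.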
|
||||
|
||||
# For paged_attention kernel
|
||||
if self.runner.sliding_window:
|
||||
start_idx = max(0, seq_len - self.runner.sliding_window)
|
||||
start_block = start_idx // block_size
|
||||
start_idx = start_block * block_size
|
||||
seq_len = seq_len - start_idx
|
||||
block_table = block_table[start_block:]
|
||||
|
||||
# For MRotaryEmbedding
|
||||
if seq_data.mrope_position_delta is not None:
|
||||
next_pos = MRotaryEmbedding.get_next_input_positions(
|
||||
seq_data.mrope_position_delta,
|
||||
context_len,
|
||||
seq_len,
|
||||
)
|
||||
for idx in range(3):
|
||||
data.input_mrope_positions[idx].extend( # type: ignore
|
||||
next_pos[idx])
|
||||
else:
|
||||
data.input_positions.append(token_positions) # type: ignore
|
||||
|
||||
# Update fields
|
||||
data.input_tokens.append(tokens)
|
||||
data.max_decode_seq_len = max(data.max_decode_seq_len, seq_len)
|
||||
data.num_decode_tokens += 1
|
||||
data.slot_mapping.append(slot)
|
||||
data.decode_block_tables.append(block_table)
|
||||
data.query_lens.append(1)
|
||||
data.seq_lens.append(seq_len)
|
||||
|
||||
def _compute_prompt_input_tokens(self, data: ModelInputData,
|
||||
seq_group_metadata: SequenceGroupMetadata,
|
||||
seq_data: SequenceData, seq_id: int):
|
||||
"""
|
||||
Compute prompt input tokens, positions, block table and slot mapping.
|
||||
"""
|
||||
token_chunk_size = seq_group_metadata.token_chunk_size
|
||||
block_size = self.runner.block_size
|
||||
|
||||
block_table = seq_group_metadata.block_tables[seq_id]
|
||||
seq_len = seq_data.get_len()
|
||||
context_len = seq_data.get_num_computed_tokens()
|
||||
seq_len = min(seq_len, context_len + token_chunk_size)
|
||||
|
||||
# For prefix caching
|
||||
prefix_cache_block_num = len(seq_group_metadata.computed_block_nums)
|
||||
if prefix_cache_block_num > 0:
|
||||
prefix_cache_len = (prefix_cache_block_num *
|
||||
self.runner.block_size)
|
||||
if prefix_cache_len <= context_len:
|
||||
# We already passed the cache hit region,
|
||||
# so do normal computation.
|
||||
pass
|
||||
elif context_len < prefix_cache_len < seq_len:
|
||||
# Partial hit. Compute the missing part.
|
||||
context_len = prefix_cache_len
|
||||
token_chunk_size = seq_len - context_len
|
||||
            elif seq_len <= prefix_cache_len:
                # Full hit. Only compute the last token to avoid
                # erroneous behavior. FIXME: Ideally we should directly
                # mark all tokens as computed in the scheduler and not
                # schedule this sequence, so this case should not happen.
                context_len = seq_len - 1
                token_chunk_size = 1
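            # Worked example (illustrative only): with block_size=16, two
            # cached blocks give prefix_cache_len=32; if context_len=0 and
            # seq_len=48, the partial-hit branch above sets context_len=32 and
            # token_chunk_size=16, so only the last 16 tokens are recomputed.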
|
||||
|
||||
tokens = seq_data.get_token_ids()
|
||||
tokens = tokens[context_len:seq_len]
|
||||
token_positions = range(context_len, seq_len)
|
||||
token_types = seq_group_metadata.token_type_ids
|
||||
|
||||
# For encoder-only models, the block_table is None,
|
||||
# and there is no need to initialize the slot_mapping.
|
||||
if block_table is not None:
|
||||
slot_mapping = [_PAD_SLOT_ID] * len(token_positions)
|
||||
for i, pos in enumerate(token_positions):
|
||||
block_number = block_table[pos // block_size]
|
||||
block_offset = pos % block_size
|
||||
slot = block_number * block_size + block_offset
|
||||
slot_mapping[i] = slot
|
||||
data.slot_mapping.extend(slot_mapping)
|
||||
|
||||
# The MROPE positions are prepared in _compute_multi_modal_input
|
||||
data.input_positions.extend(token_positions)
|
||||
|
||||
if data.token_type_ids is not None:
|
||||
data.token_type_ids.extend(token_types if token_types else [])
|
||||
|
||||
# Update fields
|
||||
data.input_tokens.extend(tokens)
|
||||
data.num_prefills += 1
|
||||
data.num_prefill_tokens += len(tokens)
|
||||
data.query_lens.append(len(tokens))
|
||||
data.prefill_block_tables.append(block_table)
|
||||
data.seq_lens.append(seq_len)
|
||||
|
||||
def _compute_multi_modal_input(self,
|
||||
seq_group_metadata: SequenceGroupMetadata,
|
||||
seq_data: SequenceData):
|
||||
computed_len = seq_data.get_num_computed_tokens()
|
||||
seq_len = self.input_data.seq_lens[-1]
|
||||
|
||||
# NOTE: mm_kwargs only includes the subset of multi-modal items that
|
||||
# intersect with the current prefill positions.
|
||||
mm_kwargs, placeholder_maps = MultiModalPlaceholderMap.from_seq_group(
|
||||
seq_group_metadata, range(computed_len, seq_len))
|
||||
|
||||
if not mm_kwargs:
|
||||
return
|
||||
|
||||
# special processing for mrope position deltas.
|
||||
if self.runner.model_config.uses_mrope:
|
||||
assert not self.chunked_prefill, \
|
||||
"MROPE on CPU does not support chunked-prefill."
|
||||
|
||||
image_grid_thw = mm_kwargs.get("image_grid_thw", None)
|
||||
video_grid_thw = mm_kwargs.get("video_grid_thw", None)
|
||||
audio_feature_lengths = mm_kwargs.get("audio_feature_lengths",
|
||||
None)
|
||||
            assert (image_grid_thw is not None or video_grid_thw is not None
                    or audio_feature_lengths is not None), (
                        "mrope embedding type requires the multi-modal input "
                        "mapper to return 'image_grid_thw', 'video_grid_thw', "
                        "or 'audio_feature_lengths'.")
|
||||
|
||||
second_per_grid_ts = mm_kwargs.get("second_per_grid_ts", None)
|
||||
use_audio_in_video = mm_kwargs.get("use_audio_in_video", False)
|
||||
hf_config = self.runner.model_config.hf_config
|
||||
token_ids = seq_data.get_token_ids()
|
||||
|
||||
mrope_positions, mrope_position_delta = \
|
||||
MRotaryEmbedding.get_input_positions(
|
||||
token_ids,
|
||||
hf_config=hf_config,
|
||||
image_grid_thw=image_grid_thw,
|
||||
video_grid_thw=video_grid_thw,
|
||||
second_per_grid_ts=second_per_grid_ts,
|
||||
context_len=computed_len,
|
||||
audio_feature_lengths=audio_feature_lengths,
|
||||
use_audio_in_video=use_audio_in_video,
|
||||
)
|
||||
seq_data.mrope_position_delta = mrope_position_delta
|
||||
|
||||
for i in range(3):
|
||||
self.input_data.input_mrope_positions[ # type: ignore
|
||||
i].extend(mrope_positions[i])
|
||||
|
||||
self.input_data.multi_modal_inputs_list.append(mm_kwargs)
|
||||
for modality, placeholder_map in placeholder_maps.items():
|
||||
self.input_data.multi_modal_placeholder_maps[modality].extend(
|
||||
placeholder_map)
|
||||
|
||||
def _prepare_lora_input(
|
||||
self, seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
is_prefill: bool) -> LoRAMapping:
|
||||
index_mapping = []
|
||||
prompt_mapping = []
|
||||
for seq in seq_group_metadata_list:
|
||||
lora_id = seq.lora_int_id
|
||||
query_len = seq.token_chunk_size
|
||||
|
||||
index_mapping += [lora_id] * query_len
|
||||
prompt_mapping += [lora_id] * (
|
||||
query_len if seq.sampling_params
|
||||
and seq.sampling_params.prompt_logprobs is not None else 1)
|
||||
|
||||
return LoRAMapping(index_mapping=tuple(index_mapping),
|
||||
prompt_mapping=tuple(prompt_mapping),
|
||||
is_prefill=is_prefill)
|
||||
|
||||
|
||||
class CPUModelRunnerBase(ModelRunnerBase[TModelInputForCPU]):
|
||||
"""
|
||||
Helper class for shared methods between CPU model runners.
|
||||
"""
|
||||
_model_input_cls: Type[TModelInputForCPU]
|
||||
_builder_cls: Type[ModelInputForCPUBuilder]
|
||||
builder: ModelInputForCPUBuilder
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
kv_cache_dtype: Optional[str] = "auto",
|
||||
is_driver_worker: bool = False,
|
||||
return_hidden_states: bool = False,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
ModelRunnerBase.__init__(self, vllm_config)
|
||||
model_config = self.model_config
|
||||
cache_config = self.cache_config
|
||||
|
||||
self.is_driver_worker = is_driver_worker
|
||||
self.return_hidden_states = return_hidden_states
|
||||
|
||||
self.device = self.device_config.device
|
||||
self.pin_memory = False
|
||||
|
||||
self.kv_cache_dtype = kv_cache_dtype
|
||||
self.sliding_window = model_config.get_sliding_window()
|
||||
self.block_size = cache_config.block_size
|
||||
num_attn_heads = self.model_config.get_num_attention_heads(
|
||||
self.parallel_config)
|
||||
needs_attn_backend = (num_attn_heads != 0
|
||||
or self.model_config.is_attention_free)
|
||||
self.attn_backend = get_attn_backend(
|
||||
self.model_config.get_head_size(),
|
||||
self.model_config.dtype,
|
||||
self.kv_cache_dtype,
|
||||
self.block_size,
|
||||
self.model_config.is_attention_free,
|
||||
use_mla=self.model_config.use_mla,
|
||||
) if needs_attn_backend else None
|
||||
|
||||
        # Lazy initialization: both are set after load_model.
        self.model: nn.Module
        self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None
|
||||
self.sampler = get_sampler()
|
||||
|
||||
if hasattr(self, "_builder_cls"):
|
||||
# multi-step model runner does not have `_builder_cls`
|
||||
self.builder = self._builder_cls(weakref.proxy(self))
|
||||
|
||||
def load_model(self) -> None:
|
||||
self.model = get_model(vllm_config=self.vllm_config)
|
||||
|
||||
if self.lora_config:
|
||||
assert supports_lora(
|
||||
self.model
|
||||
), f"{self.model.__class__.__name__} does not support LoRA yet."
|
||||
|
||||
if supports_multimodal(self.model):
|
||||
logger.warning("Regarding multimodal models, vLLM currently "
|
||||
"only supports adding LoRA to language model.")
|
||||
|
||||
# Use get_text_config() in case of multimodal models
|
||||
text_config = self.model_config.hf_config.get_text_config()
|
||||
|
||||
self.lora_manager = LRUCacheWorkerLoRAManager(
|
||||
self.scheduler_config.max_num_seqs,
|
||||
self.scheduler_config.max_num_batched_tokens,
|
||||
self.vocab_size,
|
||||
self.lora_config,
|
||||
self.device,
|
||||
self.model.embedding_modules,
|
||||
self.model.embedding_padding_modules,
|
||||
max_position_embeddings=text_config.max_position_embeddings,
|
||||
)
|
||||
self.model = self.lora_manager.create_lora_manager(self.model)
|
||||
|
||||
def get_model(self) -> nn.Module:
|
||||
return self.model
|
||||
|
||||
def _prepare_model_input_tensors(
|
||||
self,
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
finished_requests_ids: Optional[List[str]] = None
|
||||
) -> TModelInputForCPU:
|
||||
"""Helper method to prepare the model input based on a given sequence
|
||||
group. Prepares metadata needed for the base model forward pass but not
|
||||
metadata for possible additional steps, e.g., sampling.
|
||||
|
||||
"""
|
||||
self.builder.prepare(finished_requests_ids)
|
||||
self.builder.set_seq_group_list(seq_group_metadata_list)
|
||||
|
||||
return self.builder.build() # type: ignore
|
||||
|
||||
@property
|
||||
def vocab_size(self) -> int:
|
||||
return self.model_config.get_vocab_size()
|
||||
|
||||
def remove_all_loras(self):
|
||||
if not self.lora_manager:
|
||||
raise RuntimeError("LoRA is not enabled.")
|
||||
self.lora_manager.remove_all_adapters()
|
||||
|
||||
def set_active_loras(self, lora_requests: Set[LoRARequest],
|
||||
lora_mapping: LoRAMapping) -> None:
|
||||
if not self.lora_manager:
|
||||
raise RuntimeError("LoRA is not enabled.")
|
||||
self.lora_manager.set_active_adapters(lora_requests, lora_mapping)
|
||||
|
||||
def add_lora(self, lora_request: LoRARequest) -> bool:
|
||||
if not self.lora_manager:
|
||||
raise RuntimeError("LoRA is not enabled.")
|
||||
return self.lora_manager.add_adapter(lora_request)
|
||||
|
||||
def remove_lora(self, lora_id: int) -> bool:
|
||||
if not self.lora_manager:
|
||||
raise RuntimeError("LoRA is not enabled.")
|
||||
return self.lora_manager.remove_adapter(lora_id)
|
||||
|
||||
def pin_lora(self, lora_id: int) -> bool:
|
||||
if not self.lora_manager:
|
||||
raise RuntimeError("LoRA is not enabled.")
|
||||
return self.lora_manager.pin_adapter(lora_id)
|
||||
|
||||
def list_loras(self) -> Set[int]:
|
||||
if not self.lora_manager:
|
||||
raise RuntimeError("LoRA is not enabled.")
|
||||
return self.lora_manager.list_adapters()
|
||||
|
||||
|
||||
class CPUModelRunner(CPUModelRunnerBase[ModelInputForCPUWithSamplingMetadata]):
|
||||
_model_input_cls: Type[ModelInputForCPUWithSamplingMetadata] = (
|
||||
ModelInputForCPUWithSamplingMetadata)
|
||||
_builder_cls: Type[ModelInputForCPUBuilder] = ModelInputForCPUBuilder
|
||||
|
||||
def make_model_input_from_broadcasted_tensor_dict(
|
||||
self,
|
||||
tensor_dict: Dict[str, Any],
|
||||
) -> ModelInputForCPUWithSamplingMetadata:
|
||||
return ModelInputForCPUWithSamplingMetadata.from_broadcasted_tensor_dict( # noqa: E501
|
||||
tensor_dict,
|
||||
attn_backend=self.attn_backend,
|
||||
)
|
||||
|
||||
def prepare_model_input(
|
||||
self,
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
virtual_engine: int = 0,
|
||||
finished_requests_ids: Optional[List[str]] = None
|
||||
) -> ModelInputForCPUWithSamplingMetadata:
|
||||
"""Prepare the model input based on a given sequence group, including
|
||||
metadata for the sampling step.
|
||||
|
||||
"""
|
||||
model_input = self._prepare_model_input_tensors(
|
||||
seq_group_metadata_list, finished_requests_ids)
|
||||
# Sampling metadata is only required for the final pp group
|
||||
generators = self.get_generators(finished_requests_ids)
|
||||
sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list,
|
||||
model_input.seq_lens,
|
||||
model_input.query_lens,
|
||||
self.device,
|
||||
pin_memory=False,
|
||||
generators=generators)
|
||||
|
||||
is_prompt = (seq_group_metadata_list[0].is_prompt
|
||||
if seq_group_metadata_list else None)
|
||||
return dataclasses.replace(model_input,
|
||||
sampling_metadata=sampling_metadata,
|
||||
virtual_engine=virtual_engine,
|
||||
is_prompt=is_prompt)
|
||||
|
||||
@torch.no_grad()
|
||||
def execute_model(
|
||||
self,
|
||||
model_input: ModelInputForCPUWithSamplingMetadata,
|
||||
kv_caches: List[torch.Tensor],
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
num_steps: int = 1,
|
||||
previous_hidden_states: Optional[torch.Tensor] = None,
|
||||
) -> Optional[List[SamplerOutput]]:
|
||||
if num_steps > 1:
|
||||
raise ValueError(
|
||||
"CPU worker does not support multi-step execution.")
|
||||
|
||||
if self.lora_config:
|
||||
assert model_input.lora_requests is not None
|
||||
assert model_input.lora_mapping is not None
|
||||
self.set_active_loras(model_input.lora_requests,
|
||||
model_input.lora_mapping)
|
||||
|
||||
model_executable = self.model
|
||||
|
||||
multimodal_kwargs = {}
|
||||
if model_input.multi_modal_kwargs is not None:
|
||||
multimodal_kwargs = MultiModalKwargs.as_kwargs(
|
||||
model_input.multi_modal_kwargs,
|
||||
device=self.device,
|
||||
)
|
||||
execute_model_kwargs = {}
|
||||
if previous_hidden_states is not None:
|
||||
execute_model_kwargs.update(
|
||||
{"previous_hidden_states": previous_hidden_states})
|
||||
|
||||
with set_forward_context(model_input.attn_metadata, self.vllm_config,
|
||||
model_input.virtual_engine):
|
||||
hidden_states = model_executable(
|
||||
input_ids=model_input.input_tokens,
|
||||
positions=model_input.input_positions,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
**execute_model_kwargs,
|
||||
**multimodal_kwargs,
|
||||
)
|
||||
|
||||
# Compute the logits.
|
||||
logits = self.model.compute_logits(hidden_states,
|
||||
model_input.sampling_metadata)
|
||||
|
||||
# Only perform sampling in the driver worker.
|
||||
if not self.is_driver_worker:
|
||||
return []
|
||||
|
||||
# Sample the next token.
|
||||
output = self.sampler(
|
||||
logits=logits,
|
||||
sampling_metadata=model_input.sampling_metadata,
|
||||
)
|
||||
        if self.return_hidden_states:
            # We only need to pass the hidden states of the most recent token.
            if model_input.is_prompt:
|
||||
output.prefill_hidden_states = hidden_states
|
||||
output.hidden_states = hidden_states
|
||||
return [output]
|
||||
|
||||
def generate_proposals(self, *args, **kwargs):
|
||||
return self.model.generate_proposals(*args, **kwargs)
|
|
@ -0,0 +1,125 @@
|
|||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import dataclasses
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type, Union
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.model_executor.pooling_metadata import PoolingMetadata
|
||||
from vllm.multimodal import MultiModalKwargs
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.sequence import (IntermediateTensors, PoolerOutput, SequenceData,
|
||||
SequenceGroupMetadata)
|
||||
from vllm.worker.cpu_model_runner import (CPUModelRunnerBase, ModelInputForCPU,
|
||||
ModelInputForCPUBuilder)
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class ModelInputForCPUWithPoolingMetadata(ModelInputForCPU):
|
||||
"""
|
||||
Used by the CPUPoolingModelRunner.
|
||||
"""
|
||||
pooling_metadata: Optional["PoolingMetadata"] = None
|
||||
|
||||
|
||||
class CPUPoolingModelRunner(
|
||||
CPUModelRunnerBase[ModelInputForCPUWithPoolingMetadata]):
|
||||
_model_input_cls: Type[ModelInputForCPUWithPoolingMetadata] = (
|
||||
ModelInputForCPUWithPoolingMetadata)
|
||||
_builder_cls: Type[ModelInputForCPUBuilder] = ModelInputForCPUBuilder
|
||||
|
||||
@torch.inference_mode()
|
||||
def execute_model(
|
||||
self,
|
||||
model_input: ModelInputForCPUWithPoolingMetadata,
|
||||
kv_caches: List[torch.Tensor],
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
num_steps: int = 1,
|
||||
) -> Optional[Union[List[PoolerOutput], IntermediateTensors]]:
|
||||
if num_steps > 1:
|
||||
raise ValueError(
|
||||
"CPU worker does not support multi-step execution.")
|
||||
|
||||
model_executable = self.model
|
||||
cross_enc_kwargs = {}
|
||||
if model_input.token_type_ids is not None:
|
||||
cross_enc_kwargs["token_type_ids"] = model_input.token_type_ids
|
||||
execute_model_kwargs = {
|
||||
"input_ids":
|
||||
model_input.input_tokens,
|
||||
"positions":
|
||||
model_input.input_positions,
|
||||
**MultiModalKwargs.as_kwargs(
|
||||
model_input.multi_modal_kwargs or {},
|
||||
device=self.device,
|
||||
),
|
||||
**cross_enc_kwargs,
|
||||
"intermediate_tensors":
|
||||
intermediate_tensors,
|
||||
}
|
||||
|
||||
with set_forward_context(model_input.attn_metadata, self.vllm_config,
|
||||
model_input.virtual_engine):
|
||||
hidden_states = model_executable(**execute_model_kwargs)
|
||||
|
||||
# Only perform pooling in the driver worker.
|
||||
if not self.is_driver_worker:
|
||||
return []
|
||||
|
||||
return [
|
||||
self.model.pooler(hidden_states=hidden_states,
|
||||
pooling_metadata=model_input.pooling_metadata)
|
||||
]
|
||||
|
||||
def make_model_input_from_broadcasted_tensor_dict(
|
||||
self,
|
||||
tensor_dict: Dict[str,
|
||||
Any]) -> ModelInputForCPUWithPoolingMetadata:
|
||||
return ModelInputForCPUWithPoolingMetadata.from_broadcasted_tensor_dict(
|
||||
tensor_dict,
|
||||
attn_backend=self.attn_backend,
|
||||
)
|
||||
|
||||
def prepare_model_input(
|
||||
self,
|
||||
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
|
||||
virtual_engine: int = 0,
|
||||
finished_requests_ids: Optional[List[str]] = None
|
||||
) -> ModelInputForCPUWithPoolingMetadata:
|
||||
assert seq_group_metadata_list is not None
|
||||
model_input = self._prepare_model_input_tensors(
|
||||
seq_group_metadata_list, finished_requests_ids)
|
||||
# Prepare PoolingMetadata.
|
||||
assert model_input.seq_lens is not None
|
||||
pooling_metadata = self._prepare_pooling(seq_group_metadata_list,
|
||||
model_input.seq_lens)
|
||||
|
||||
return dataclasses.replace(model_input,
|
||||
virtual_engine=virtual_engine,
|
||||
pooling_metadata=pooling_metadata)
|
||||
|
||||
def _prepare_pooling(
|
||||
self,
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
prompt_lens: List[int],
|
||||
) -> PoolingMetadata:
|
||||
"""Prepare PoolingMetadata for the sequence group metadata list."""
|
||||
seq_groups: List[Tuple[List[int], PoolingParams]] = []
|
||||
for i, seq_group_metadata in enumerate(seq_group_metadata_list):
|
||||
seq_ids = list(seq_group_metadata.seq_data.keys())
|
||||
pooling_params = seq_group_metadata.pooling_params
|
||||
seq_groups.append((seq_ids, pooling_params))
|
||||
|
||||
seq_data: Dict[int, SequenceData] = {}
|
||||
for seq_group_metadata in seq_group_metadata_list:
|
||||
seq_data.update(seq_group_metadata.seq_data)
|
||||
|
||||
pooling_metadata = PoolingMetadata(
|
||||
seq_groups=seq_groups,
|
||||
seq_data=seq_data,
|
||||
prompt_lens=prompt_lens,
|
||||
)
|
||||
|
||||
return pooling_metadata
|
|
@ -0,0 +1,452 @@
|
|||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""A CPU worker class."""
|
||||
import os
|
||||
from importlib import util
|
||||
from typing import List, Optional, Set, Tuple, Type
|
||||
|
||||
import torch
|
||||
import torch.distributed
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.attention import get_attn_backend
|
||||
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
|
||||
ParallelConfig, VllmConfig)
|
||||
from vllm.distributed import (ensure_model_parallel_initialized,
|
||||
init_distributed_environment)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.model_executor import set_random_seed
|
||||
from vllm.sequence import ExecuteModelRequest
|
||||
from vllm.utils import bind_kv_cache
|
||||
from vllm.worker.cpu_enc_dec_model_runner import CPUEncoderDecoderModelRunner
|
||||
from vllm.worker.cpu_model_runner import CPUModelRunner, CPUModelRunnerBase
|
||||
from vllm.worker.cpu_pooling_model_runner import CPUPoolingModelRunner
|
||||
from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase,
|
||||
WorkerInput)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class CPUCacheEngine:
|
||||
"""Manages the KV cache for CPU backend.
|
||||
|
||||
This class is responsible for initializing and managing CPU KV
|
||||
caches. It also provides methods for performing KV cache operations, such
|
||||
as copying.
|
||||
"""
|
||||
|
||||
def __init__(self, cache_config: CacheConfig, model_config: ModelConfig,
|
||||
parallel_config: ParallelConfig,
|
||||
device_config: DeviceConfig) -> None:
|
||||
assert device_config.device_type == "cpu"
|
||||
self.cache_config = cache_config
|
||||
self.model_config = model_config
|
||||
self.parallel_config = parallel_config
|
||||
|
||||
self.head_size = model_config.get_head_size()
|
||||
self.num_layers = model_config.get_num_layers(parallel_config)
|
||||
self.num_heads = model_config.get_num_kv_heads(parallel_config)
|
||||
|
||||
self.block_size = cache_config.block_size
|
||||
        # Note: In CacheConfig, num_gpu_blocks is actually num_cpu_blocks
        # for the CPU backend, because we want to reuse the KV cache
        # management in the scheduler.
        self.num_cpu_blocks = cache_config.num_gpu_blocks
|
||||
|
||||
self.dtype = CPUCacheEngine.get_kv_cache_dtype(cache_config,
|
||||
model_config)
|
||||
|
||||
# Get attention backend.
|
||||
self.attn_backend = get_attn_backend(
|
||||
self.model_config.get_head_size(),
|
||||
self.model_config.dtype,
|
||||
cache_config.cache_dtype,
|
||||
self.block_size,
|
||||
self.model_config.is_attention_free,
|
||||
use_mla=self.model_config.use_mla,
|
||||
)
|
||||
|
||||
# Initialize the cache.
|
||||
self.cpu_cache = self._allocate_kv_cache(self.num_cpu_blocks)
|
||||
|
||||
def _allocate_kv_cache(
|
||||
self,
|
||||
num_blocks: int,
|
||||
) -> List[torch.Tensor]:
|
||||
"""Allocates KV cache on CPU."""
|
||||
kv_cache_shape = self.attn_backend.get_kv_cache_shape(
|
||||
num_blocks, self.block_size, self.num_heads, self.head_size)
|
||||
kv_cache: List[torch.Tensor] = []
|
||||
for _ in range(self.num_layers):
|
||||
kv_cache.append(
|
||||
torch.empty(kv_cache_shape, dtype=self.dtype, device="cpu"))
|
||||
return kv_cache
|
||||
|
||||
def swap_in(self, src_to_dst: torch.Tensor) -> None:
|
||||
raise NotImplementedError("Swap is not supported in CPUCacheEngine.")
|
||||
|
||||
def swap_out(self, src_to_dst: torch.Tensor) -> None:
|
||||
raise NotImplementedError("Swap is not supported in CPUCacheEngine.")
|
||||
|
||||
def copy(self, src_to_dsts: torch.Tensor) -> None:
|
||||
self.attn_backend.copy_blocks(self.cpu_cache, src_to_dsts)
|
||||
|
||||
@staticmethod
|
||||
def get_kv_cache_dtype(cache_config: CacheConfig,
|
||||
model_config: ModelConfig):
|
||||
if cache_config.cache_dtype == "auto":
|
||||
return model_config.dtype
|
||||
elif cache_config.cache_dtype in ["fp8", "fp8_e5m2"]:
|
||||
return torch.float8_e5m2
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported KV cache type "
|
||||
f"{cache_config.cache_dtype}.")
|
||||
|
||||
@staticmethod
|
||||
def get_cache_block_size(
|
||||
cache_config: CacheConfig,
|
||||
model_config: ModelConfig,
|
||||
parallel_config: ParallelConfig,
|
||||
) -> int:
|
||||
head_size = model_config.get_head_size()
|
||||
num_heads = model_config.get_num_kv_heads(parallel_config)
|
||||
num_layers = model_config.get_num_layers(parallel_config)
|
||||
|
||||
key_cache_block = cache_config.block_size * num_heads * head_size
|
||||
value_cache_block = key_cache_block if not model_config.use_mla else 0
|
||||
total = num_layers * (key_cache_block + value_cache_block)
|
||||
dtype = CPUCacheEngine.get_kv_cache_dtype(cache_config, model_config)
|
||||
dtype_size = torch.tensor([], dtype=dtype).element_size()
|
||||
return dtype_size * total
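        # Illustrative sizing example (assumed numbers, not from the original
        # source): for 32 layers, 8 KV heads, head_size=128, block_size=16 and
        # a bf16 KV cache, one block takes
        # 32 * (16 * 8 * 128 + 16 * 8 * 128) * 2 bytes = 2 MiB.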
|
||||
|
||||
|
||||
class CPUWorker(LocalOrDistributedWorkerBase):
|
||||
"""A worker class that executes (a partition of) the model on a CPU socket.
|
||||
|
||||
Each worker is associated with a single CPU socket. The worker is
|
||||
responsible for maintaining the KV cache and executing the model on the
|
||||
CPU. In case of distributed inference, each worker is assigned a partition
|
||||
of the model.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
local_rank: int,
|
||||
rank: int,
|
||||
distributed_init_method: str,
|
||||
kv_cache_dtype: Optional[str] = "auto",
|
||||
is_driver_worker: bool = False,
|
||||
model_runner_cls: Optional[Type[CPUModelRunner]] = None,
|
||||
) -> None:
|
||||
WorkerBase.__init__(self, vllm_config=vllm_config)
|
||||
|
||||
self.local_rank = local_rank
|
||||
self.rank = rank
|
||||
vllm_config.parallel_config.rank = rank
|
||||
|
||||
self.distributed_init_method = distributed_init_method
|
||||
|
||||
self.is_driver_worker = is_driver_worker
|
||||
if self.is_driver_worker:
|
||||
assert self.rank == 0, "The driver worker must have rank 0."
|
||||
|
||||
if self.model_config.trust_remote_code:
|
||||
# note: lazy import to avoid importing torch before initializing
|
||||
from vllm.utils import init_cached_hf_modules
|
||||
init_cached_hf_modules()
|
||||
|
||||
        # Set up OpenMP thread affinity.
        omp_cpuids = envs.VLLM_CPU_OMP_THREADS_BIND
        self.local_omp_cpuid = "all"
        if omp_cpuids == "auto":
            self.local_omp_cpuid = \
                self.get_cpus_id_binding_based_on_numa_nodes()
        else:
            self.local_omp_cpuid = omp_cpuids.split("|")[rank]
|
||||
|
||||
# Return hidden states from target model if the draft model is an
|
||||
# mlp_speculator
|
||||
speculative_config = self.speculative_config
|
||||
model_config = self.model_config
|
||||
speculative_args = {} if speculative_config is None \
|
||||
or (speculative_config.draft_model_config.model ==
|
||||
model_config.model) \
|
||||
or (speculative_config.draft_model_config.hf_config.model_type
|
||||
not in ["medusa", "mlp_speculator", "eagle"]) \
|
||||
else {"return_hidden_states": True}
|
||||
ModelRunnerClass: Type[CPUModelRunnerBase] = CPUModelRunner
|
||||
if self.model_config.runner_type == "pooling":
|
||||
ModelRunnerClass = CPUPoolingModelRunner
|
||||
elif self.model_config.is_encoder_decoder:
|
||||
ModelRunnerClass = CPUEncoderDecoderModelRunner
|
||||
self.model_runner: CPUModelRunnerBase = ModelRunnerClass(
|
||||
vllm_config=vllm_config,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
is_driver_worker=is_driver_worker,
|
||||
**speculative_args,
|
||||
)
|
||||
if model_runner_cls is not None:
|
||||
self.model_runner = model_runner_cls(self.model_runner)
|
||||
# Uninitialized cache engine. Will be initialized by
|
||||
# initialize_cache.
|
||||
self.cache_engine: List[CPUCacheEngine]
|
||||
# Initialize cpu_cache as pooling models don't initialize kv_caches
|
||||
self.cpu_cache: Optional[List[List[torch.Tensor]]] = None
|
||||
|
||||
# Torch profiler. Enabled and configured through env vars:
|
||||
# VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
|
||||
if envs.VLLM_TORCH_PROFILER_DIR:
|
||||
torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
|
||||
logger.info("Profiling enabled. Traces will be saved to: %s",
|
||||
torch_profiler_trace_dir)
|
||||
self.profiler = torch.profiler.profile(
|
||||
activities=[
|
||||
torch.profiler.ProfilerActivity.CPU,
|
||||
],
|
||||
with_stack=True,
|
||||
on_trace_ready=torch.profiler.tensorboard_trace_handler(
|
||||
torch_profiler_trace_dir, use_gzip=True))
|
||||
else:
|
||||
self.profiler = None
|
||||
|
||||
def start_profile(self):
|
||||
if self.profiler is None:
|
||||
raise RuntimeError("Profiler is not enabled.")
|
||||
self.profiler.start()
|
||||
|
||||
def stop_profile(self):
|
||||
if self.profiler is None:
|
||||
raise RuntimeError("Profiler is not enabled.")
|
||||
self.profiler.stop()
|
||||
|
||||
def init_device(self) -> None:
|
||||
if self.local_omp_cpuid != "all":
|
||||
ret = torch.ops._C_utils.init_cpu_threads_env(self.local_omp_cpuid)
|
||||
if ret:
|
||||
logger.info(ret)
|
||||
|
||||
# Note: unique identifier for creating allreduce shared memory
|
||||
os.environ["VLLM_DIST_IDENT"] = self.distributed_init_method.split(
|
||||
":")[-1]
|
||||
self.device = torch.device("cpu")
|
||||
self.init_distributed_environment()
|
||||
# Set random seed.
|
||||
set_random_seed(self.model_config.seed)
|
||||
|
||||
def load_model(self):
|
||||
self.model_runner.load_model()
|
||||
|
||||
def determine_num_available_blocks(self) -> Tuple[int, int]:
|
||||
"""Determine the number of blocks available for the KV cache.
|
||||
|
||||
This determines how many KV blocks can fit into the configured CPU
|
||||
KV cache space.
|
||||
|
||||
Note that since vLLM assumes a block resides on GPU if it can be
|
||||
modified, we return num_gpu_blocks=num_cpu_blocks and num_cpu_blocks=0.
|
||||
This allows us to reuse the scheduler of vLLM without generalizing it
|
||||
to different devices.
|
||||
"""
|
||||
# For CPU device, the block number will be calculated based on the
|
||||
# cpu_kvcache_space.
|
||||
cache_block_size = self.get_cache_block_size_bytes()
|
||||
num_cpu_blocks = int(self.cache_config.cpu_kvcache_space_bytes //
|
||||
cache_block_size)
|
||||
num_cpu_blocks = max(num_cpu_blocks, 0)
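        # Illustrative example (assumed numbers): with VLLM_CPU_KVCACHE_SPACE
        # set to 4 GiB and a cache block size of 2 MiB, roughly 2048 CPU
        # blocks are reported to the scheduler as "GPU" blocks below.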
|
||||
|
||||
# Note: To reuse the cache management procedure,
|
||||
# use cpu cache as 'gpu cache'.
|
||||
num_gpu_blocks = num_cpu_blocks
|
||||
num_cpu_blocks = 0
|
||||
return num_gpu_blocks, num_cpu_blocks
|
||||
|
||||
def initialize_cache(self, num_gpu_blocks: int,
|
||||
num_cpu_blocks: int) -> None:
|
||||
"""Initialize the KV cache. Currently, swappable CPU memory is not
|
||||
supported.
|
||||
|
||||
Since this worker does not support GPUs, we use the num_gpu_blocks to
|
||||
determine how many non-swappable CPU blocks to allocate.
|
||||
"""
|
||||
assert (num_cpu_blocks == 0
|
||||
), f"{type(self)} does not support swappable cache"
|
||||
|
||||
# Note: To reuse the cache management procedure,
|
||||
# use cpu cache as 'gpu cache'.
|
||||
num_cpu_blocks = num_gpu_blocks
|
||||
|
||||
self._validate_num_cpu_blocks(num_cpu_blocks)
|
||||
self.cache_config.num_gpu_blocks = num_cpu_blocks
|
||||
self.cache_config.num_cpu_blocks = 0
|
||||
|
||||
# Initialize the cache.
|
||||
self._init_cache_engine()
|
||||
|
||||
def add_lora(self, lora_request: LoRARequest) -> bool:
|
||||
return self.model_runner.add_lora(lora_request)
|
||||
|
||||
def remove_lora(self, lora_id: int) -> bool:
|
||||
return self.model_runner.remove_lora(lora_id)
|
||||
|
||||
def pin_lora(self, lora_id: int) -> bool:
|
||||
return self.model_runner.pin_lora(lora_id)
|
||||
|
||||
def list_loras(self) -> Set[int]:
|
||||
return self.model_runner.list_loras()
|
||||
|
||||
def _validate_num_cpu_blocks(self, num_cpu_blocks: int) -> None:
|
||||
"""Raise errors if the num_cpu_blocks is invalid.
|
||||
"""
|
||||
if num_cpu_blocks <= 0:
|
||||
raise ValueError("No available memory for the cache blocks. "
|
||||
"Try increasing `VLLM_CPU_KVCACHE_SPACE` when "
|
||||
"initializing the engine.")
|
||||
|
||||
        max_seq_len = self.cache_config.block_size * num_cpu_blocks
        if self.model_config.max_model_len > max_seq_len:
            raise ValueError(
                f"The model's max seq len ({self.model_config.max_model_len}) "
                "is larger than the maximum number of tokens that can be "
                f"stored in KV cache ({max_seq_len}). Try increasing "
                "`VLLM_CPU_KVCACHE_SPACE` or decreasing `max_model_len` when "
                "initializing the engine.")

    def _init_cache_engine(self) -> None:
        self.cache_engine = [
            CPUCacheEngine(self.cache_config, self.model_config,
                           self.parallel_config, self.device_config)
            for _ in range(self.parallel_config.pipeline_parallel_size)
        ]
        self.cpu_cache = [
            self.cache_engine[ve].cpu_cache
            for ve in range(self.parallel_config.pipeline_parallel_size)
        ]
        bind_kv_cache(self.compilation_config.static_forward_context,
                      self.cpu_cache)
        self.model_runner.block_size = self.cache_engine[0].block_size

        assert all(
            self.cpu_cache[ve] is not None
            for ve in range(self.parallel_config.pipeline_parallel_size))

        # Populate the cache to warm up the memory.
        for ve in range(self.parallel_config.pipeline_parallel_size):
            for layer_cache in self.cpu_cache[ve]:
                layer_cache.fill_(0)

    @property
    def do_metadata_broadcast(self) -> bool:
        return self.parallel_config.tensor_parallel_size > 1

    @property
    def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
        return self.cpu_cache

    @property
    def vocab_size(self) -> int:
        return self.model_runner.vocab_size

    @property
    def max_model_len(self) -> int:
        return self.model_config.max_model_len

    def execute_worker(
        self,
        worker_input: WorkerInput,
    ) -> None:
        if (worker_input.blocks_to_copy is not None
                and worker_input.blocks_to_copy.numel() > 0):
            self.cache_engine[worker_input.virtual_engine].copy(
                worker_input.blocks_to_copy)

    @torch.inference_mode()
    def prepare_worker_input(
            self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
        assert execute_model_req is not None
        virtual_engine: int = execute_model_req.virtual_engine
        num_seq_groups: int = len(execute_model_req.seq_group_metadata_list)
        blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy,
                                      device="cpu",
                                      dtype=torch.int64).view(-1, 2)
        assert len(execute_model_req.blocks_to_swap_in) == 0
        assert len(execute_model_req.blocks_to_swap_out) == 0
        return WorkerInput(
            num_seq_groups=num_seq_groups,
            blocks_to_copy=blocks_to_copy,
            virtual_engine=virtual_engine,
        )

    def init_distributed_environment(self) -> None:
        """Initialize the distributed environment."""

        parallel_config = self.parallel_config
        rank = self.rank
        distributed_init_method = self.distributed_init_method
        init_distributed_environment(
            world_size=parallel_config.world_size,
            rank=rank,
            distributed_init_method=distributed_init_method,
            backend="gloo",
        )

        # A small all_reduce for warmup.
        torch.distributed.all_reduce(torch.zeros(1).cpu())

        ensure_model_parallel_initialized(
            parallel_config.tensor_parallel_size,
            parallel_config.pipeline_parallel_size)

    def get_cache_block_size_bytes(self) -> int:
        """Return the size in bytes of a single KV cache block."""
        return CPUCacheEngine.get_cache_block_size(self.cache_config,
                                                   self.model_config,
                                                   self.parallel_config)

    def get_cpus_id_binding_based_on_numa_nodes(self) -> str:
        """Return the CPU id binding based on NUMA nodes."""
        rank_to_cpus = self.local_omp_cpuid
        # Set up OpenMP thread affinity based on NUMA nodes automatically.
        world_size = self.vllm_config.parallel_config.world_size
        libnuma_found = util.find_spec("numa") is not None
        psutil_found = util.find_spec("psutil") is not None
        if libnuma_found and psutil_found:
            import psutil
            from numa import info
            cpu_count = psutil.cpu_count(logical=False)
            cpus_allow_list = psutil.Process().cpu_affinity()
            numa_size = info.get_num_configured_nodes()
            cpu_count_per_numa = cpu_count // numa_size
            num_of_reserved_cpu = min(envs.VLLM_CPU_NUM_OF_RESERVED_CPU,
                                      cpu_count_per_numa // 2)

            # Build the list of NUMA nodes whose CPUs are in the allow list.
            node_to_cpus = []
            for i in range(numa_size):
                node_intersect = set(
                    info.node_to_cpus(i)).intersection(cpus_allow_list)
                if bool(node_intersect):
                    node_to_cpus.append(list(node_intersect))

            if world_size > len(node_to_cpus):
                logger.error(
                    "Auto thread-binding failed because the "
                    "world size (%d) is larger than the "
                    "number of allowed NUMA nodes (%d). "
                    "Please try to bind threads manually.", world_size,
                    len(node_to_cpus))
            else:
                end = cpu_count_per_numa - num_of_reserved_cpu
                rank_to_cpus_list = node_to_cpus[self.rank][:end]
                rank_to_cpus = ','.join(str(x) for x in rank_to_cpus_list)
                logger.info("auto thread-binding list: %s", rank_to_cpus)
        else:
            logger.warning(
                "Auto thread-binding is not supported because the "
                "numa and psutil packages are not installed; "
                "falling back to no thread-binding. To get better "
                "performance, please try to bind threads manually.")
        return rank_to_cpus
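# Illustrative, standalone sketch (not part of the file above) of how the
# auto thread-binding logic derives a per-rank CPU list. It assumes the
# optional `psutil` and `py-libnuma` packages are installed; `rank`,
# `world_size`, and `reserved` are hypothetical example inputs.
from importlib import util as _util


def _sketch_cpu_binding(rank: int, world_size: int, reserved: int) -> str:
    if _util.find_spec("numa") is None or _util.find_spec("psutil") is None:
        return "all"  # no binding; mirrors the warning branch above
    import psutil
    from numa import info

    allowed = set(psutil.Process().cpu_affinity())
    # Physical CPUs of each NUMA node, restricted to the process affinity.
    nodes = [
        sorted(set(info.node_to_cpus(i)) & allowed)
        for i in range(info.get_num_configured_nodes())
    ]
    nodes = [n for n in nodes if n]
    if world_size > len(nodes):
        return "all"  # not enough NUMA nodes; mirrors the error branch
    per_node = psutil.cpu_count(logical=False) // len(nodes)
    end = max(per_node - reserved, 1)
    return ",".join(str(c) for c in nodes[rank][:end])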
@ -0,0 +1,108 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import dataclasses
from typing import Dict, Optional, Tuple

import torch

from vllm.distributed import broadcast_tensor_dict
from vllm.sequence import ExecuteModelRequest
from vllm.worker.tpu_model_runner import ModelInputForTPU
from vllm.worker.tpu_worker import TPUWorker
from vllm.worker.worker_base import WorkerInput


class MultiStepTPUWorker(TPUWorker):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.cached_model_input: Optional[ModelInputForTPU] = None

    def _get_driver_input_and_broadcast(
        self, execute_model_req: ExecuteModelRequest
    ) -> Tuple[ModelInputForTPU, WorkerInput, Dict[str, torch.Tensor]]:
        assert self.is_driver_worker
        assert execute_model_req.virtual_engine == 0

        is_first_multi_step = execute_model_req.is_first_multi_step
        is_last_step = execute_model_req.is_last_step
        if is_first_multi_step:
            worker_input: WorkerInput = self.prepare_worker_input(
                execute_model_req=execute_model_req)
            worker_input = dataclasses.replace(
                worker_input,
                num_steps=execute_model_req.num_lookahead_slots + 1)
            model_input: ModelInputForTPU = (
                self.model_runner.prepare_model_input(
                    execute_model_req.seq_group_metadata_list,
                    execute_model_req.virtual_engine,
                    execute_model_req.finished_requests_ids))

            if execute_model_req.async_callback:
                model_input = dataclasses.replace(
                    model_input,
                    async_callback=execute_model_req.async_callback)
        else:
            assert self.cached_model_input is not None
            model_input = self.cached_model_input
            worker_input = WorkerInput()
        model_input = dataclasses.replace(
            model_input,
            is_first_multi_step=is_first_multi_step,
            is_last_step=is_last_step)

        if self.do_metadata_broadcast:
            if is_first_multi_step:
                broadcast_data = worker_input.as_broadcastable_tensor_dict()
                broadcast_data.update(
                    model_input.as_broadcastable_tensor_dict())
                broadcast_tensor_dict(broadcast_data, src=0)
            else:
                broadcast_data = {
                    "is_first_multi_step": is_first_multi_step,
                    "is_last_step": is_last_step,
                }
                broadcast_tensor_dict(broadcast_data, src=0)

        # Returning an empty dict here keeps this compatible with
        # `LocalOrDistributedWorkerBase._get_driver_input_and_broadcast`.
        return model_input, worker_input, {}

    def prepare_input(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None,
    ) -> Optional[Tuple[ModelInputForTPU, WorkerInput, Dict[str,
                                                            torch.Tensor]]]:
        if self.is_driver_worker:
            if execute_model_req is None:
                if self.do_metadata_broadcast:
                    broadcast_tensor_dict({}, src=0)
                return None

            model_input, worker_input, _ = self._get_driver_input_and_broadcast(
                execute_model_req)
            if model_input.is_first_multi_step:
                self.cached_model_input = model_input
            return model_input, worker_input, {}
        else:
            broadcast_data = broadcast_tensor_dict(src=0)
            if not broadcast_data:
                return None

            if len(broadcast_data) == 2:
                assert self.cached_model_input is not None
                self.cached_model_input = dataclasses.replace(
                    self.cached_model_input,
                    is_first_multi_step=broadcast_data["is_first_multi_step"],
                    is_last_step=broadcast_data["is_last_step"])
                empty_worker_input = WorkerInput()
                return self.cached_model_input, empty_worker_input, {}

            worker_input = WorkerInput.from_broadcasted_tensor_dict(
                broadcast_data)
            model_input = (
                self.model_runner.
                make_model_input_from_broadcasted_tensor_dict(broadcast_data))
            self.cached_model_input = model_input
            return model_input, worker_input, {}
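# Illustrative sketch (not part of the file above) of the broadcast payloads
# exchanged between driver and non-driver workers in the multi-step loop.
# The dictionaries below are hypothetical stand-ins for the real tensor
# dicts built by `as_broadcastable_tensor_dict()`.
first_step_payload = {
    # worker_input merged with model_input: many entries.
    "num_seq_groups": 4,
    "token_ids": "...tensor...",
    "is_first_multi_step": True,
    "is_last_step": False,
}
later_step_payload = {
    # Exactly two entries, which is how non-driver workers recognize a
    # non-first step (`len(broadcast_data) == 2`) and reuse the cached input.
    "is_first_multi_step": False,
    "is_last_step": True,
}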
@ -0,0 +1,909 @@
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import enum
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple,
|
||||
Type, Union)
|
||||
from unittest.mock import patch
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch_xla.core.xla_model as xm
|
||||
import torch_xla.runtime as xr
|
||||
|
||||
from vllm.attention import AttentionMetadata, get_attn_backend
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.forward_context import get_forward_context, set_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.model_executor.model_loader import get_model
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors,
|
||||
Logprob, SequenceGroupMetadata, SequenceOutput)
|
||||
from vllm.worker.model_runner_base import (
|
||||
ModelRunnerBase, ModelRunnerInputBase,
|
||||
_add_attn_metadata_broadcastable_dict,
|
||||
_init_attn_metadata_from_tensor_dict)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# Here we utilize the behavior that an out-of-bound index is ignored.
|
||||
# FIXME(woosuk): Find a more reliable way to prevent possible bugs.
|
||||
_PAD_SLOT_ID = 1_000_000_000
|
||||
# FIXME(woosuk): Temporarily disabled top-p sampling since it's too slow.
|
||||
_ENABLE_TOP_P = False
|
||||
# FIXME(woosuk): A temporary hack to support `n > 1`.
|
||||
# This can significantly affect the performance if too large.
|
||||
_MAX_NUM_SAMPLES = 128
|
||||
|
||||
|
||||
class ExecutionMode(enum.Enum):
|
||||
PREFILL = enum.auto()
|
||||
DECODE = enum.auto()
|
||||
PREFIX_PREFILL = enum.auto()
|
||||
|
||||
def is_prefill(self) -> bool:
|
||||
return self in (ExecutionMode.PREFILL, ExecutionMode.PREFIX_PREFILL)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ModelInputForTPU(ModelRunnerInputBase):
|
||||
token_ids: torch.Tensor
|
||||
position_ids: torch.Tensor
|
||||
attn_metadata: AttentionMetadata
|
||||
input_lens: torch.Tensor
|
||||
t: torch.Tensor
|
||||
p: torch.Tensor
|
||||
num_samples: int
|
||||
n: List[int]
|
||||
seq_groups: List[List[int]]
|
||||
is_first_multi_step: bool = True
|
||||
is_last_step: bool = True
|
||||
virtual_engine: int = 0
|
||||
async_callback: Optional[Callable] = None
|
||||
|
||||
def as_broadcastable_tensor_dict(
|
||||
self) -> Dict[str, Union[int, torch.Tensor]]:
|
||||
tensor_dict = {
|
||||
"token_ids": self.token_ids,
|
||||
"position_ids": self.position_ids,
|
||||
"input_lens": self.input_lens,
|
||||
"t": self.t,
|
||||
"p": self.p,
|
||||
"num_samples": self.num_samples,
|
||||
"n": self.n,
|
||||
"seq_groups": self.seq_groups,
|
||||
"is_first_multi_step": self.is_first_multi_step,
|
||||
"is_last_step": self.is_last_step,
|
||||
"virtual_engine": self.virtual_engine,
|
||||
}
|
||||
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
|
||||
return tensor_dict
|
||||
|
||||
@classmethod
|
||||
def from_broadcasted_tensor_dict(
|
||||
cls: Type["ModelInputForTPU"],
|
||||
tensor_dict: Dict[str, Any],
|
||||
attn_backend: Optional["AttentionBackend"] = None,
|
||||
) -> "ModelInputForTPU":
|
||||
if attn_backend is not None:
|
||||
tensor_dict = _init_attn_metadata_from_tensor_dict(
|
||||
attn_backend, tensor_dict)
|
||||
return cls(**tensor_dict)
|
||||
|
||||
|
||||
class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
is_driver_worker: bool = False,
|
||||
):
|
||||
ModelRunnerBase.__init__(self, vllm_config=vllm_config)
|
||||
self.is_driver_worker = is_driver_worker
|
||||
|
||||
self.block_size = self.cache_config.block_size
|
||||
self.max_num_blocks_per_seq = (self.model_config.max_model_len //
|
||||
self.block_size)
|
||||
self.block_tables = np.zeros(
|
||||
(self.scheduler_config.max_num_seqs, self.max_num_blocks_per_seq),
|
||||
dtype=np.int32)
|
||||
self.attn_backend = get_attn_backend(
|
||||
self.model_config.get_head_size(),
|
||||
self.model_config.dtype,
|
||||
self.cache_config.cache_dtype,
|
||||
self.block_size,
|
||||
self.model_config.is_attention_free,
|
||||
False,
|
||||
)
|
||||
self.cached_step_outputs: List[torch.Tensor] = []
|
||||
|
||||
smem_size = 512 * 1024
|
||||
block_table_size = 4 * self.block_tables.size
|
||||
if block_table_size >= smem_size:
|
||||
logger.warning(
|
||||
"The max_model_len (%d) is too large. This may degrade the "
|
||||
"performance due to the insufficient smem size. Consider "
|
||||
"setting --max-model-len to a smaller value, like %d.",
|
||||
self.model_config.max_model_len,
|
||||
self.model_config.max_model_len /
|
||||
(block_table_size / smem_size))
|
||||
|
||||
def load_model(self) -> None:
|
||||
self.device = self.device_config.device
|
||||
|
||||
# NOTE(woosuk): While the executor assigns the TP ranks to the worker
|
||||
# process, the ranks can be different from the ranks internally assigned
|
||||
# by the xm runtime. Therefore, there is a mismatch in the rank
|
||||
# assignment between the gloo (cpu) runtime and the xm (tpu) runtime.
|
||||
# This is not a problem in linear layers because all-reduce is
|
||||
# rank-agnostic. However, it matters for all-gather as the ranks
|
||||
# determine the order of concatenating the output tensors.
|
||||
# As a workaround, we use the xm's rank assignment only when loading
|
||||
# the embedding weights.
|
||||
xm_tp_rank = xr.global_ordinal()
|
||||
with patch(
|
||||
"vllm.model_executor.layers.vocab_parallel_embedding."
|
||||
"get_tensor_model_parallel_rank",
|
||||
return_value=xm_tp_rank):
|
||||
model = get_model(vllm_config=self.vllm_config)
|
||||
model = model.eval()
|
||||
xm.wait_device_ops()
|
||||
model = ModelWrapper(model)
|
||||
self.model = torch.compile(model,
|
||||
backend="openxla",
|
||||
fullgraph=True,
|
||||
dynamic=False)
|
||||
|
||||
def get_model(self) -> nn.Module:
|
||||
return self.model.model
|
||||
|
||||
def _dummy_run(
|
||||
self,
|
||||
batch_size: int,
|
||||
seq_len: int,
|
||||
kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||
exec_mode: ExecutionMode,
|
||||
) -> None:
|
||||
exec_mode = ExecutionMode(exec_mode)
|
||||
if exec_mode.is_prefill():
|
||||
seq_len = (seq_len + 15) // 16 * 16
|
||||
token_ids = torch.zeros((batch_size, seq_len),
|
||||
dtype=torch.int32,
|
||||
device=self.device)
|
||||
position_ids = torch.zeros((batch_size, seq_len),
|
||||
dtype=torch.int32,
|
||||
device=self.device)
|
||||
slot_mapping = torch.zeros((batch_size, seq_len),
|
||||
dtype=torch.int64,
|
||||
device=self.device)
|
||||
input_lens = torch.ones((batch_size, ),
|
||||
dtype=torch.int32,
|
||||
device=self.device)
|
||||
if exec_mode == ExecutionMode.PREFILL:
|
||||
attn_metadata = self.attn_backend.make_metadata(
|
||||
num_prefills=batch_size,
|
||||
num_prefill_tokens=batch_size * seq_len,
|
||||
num_decode_tokens=0,
|
||||
slot_mapping=slot_mapping,
|
||||
multi_modal_placeholder_index_maps=None,
|
||||
enable_kv_scales_calculation=False,
|
||||
block_tables=None,
|
||||
context_lens=None,
|
||||
effective_query_lens=None,
|
||||
)
|
||||
else:
|
||||
context_lens = torch.ones((batch_size, ),
|
||||
dtype=torch.int32,
|
||||
device=self.device)
|
||||
block_tables = torch.tensor(self.block_tables[:batch_size],
|
||||
dtype=torch.int32,
|
||||
device=self.device)
|
||||
effective_query_lens = torch.ones_like(context_lens)
|
||||
attn_metadata = self.attn_backend.make_metadata(
|
||||
num_prefills=batch_size,
|
||||
num_prefill_tokens=batch_size * seq_len,
|
||||
num_decode_tokens=0,
|
||||
slot_mapping=slot_mapping,
|
||||
multi_modal_placeholder_index_maps=None,
|
||||
enable_kv_scales_calculation=False,
|
||||
block_tables=block_tables,
|
||||
context_lens=context_lens,
|
||||
effective_query_lens=effective_query_lens,
|
||||
)
|
||||
else:
|
||||
assert seq_len == 1
|
||||
token_ids = torch.zeros((batch_size, seq_len),
|
||||
dtype=torch.int32,
|
||||
device=self.device)
|
||||
position_ids = torch.zeros((batch_size, seq_len),
|
||||
dtype=torch.int32,
|
||||
device=self.device)
|
||||
slot_mapping = torch.zeros((batch_size, seq_len),
|
||||
dtype=torch.int64,
|
||||
device=self.device)
|
||||
block_tables = torch.zeros(
|
||||
(batch_size, self.max_num_blocks_per_seq),
|
||||
dtype=torch.int32,
|
||||
device=self.device)
|
||||
context_lens = torch.ones((batch_size, ),
|
||||
dtype=torch.int32,
|
||||
device=self.device)
|
||||
input_lens = torch.ones((batch_size, ),
|
||||
dtype=torch.int32,
|
||||
device=self.device)
|
||||
attn_metadata = self.attn_backend.make_metadata(
|
||||
num_prefills=0,
|
||||
num_prefill_tokens=0,
|
||||
num_decode_tokens=batch_size * seq_len,
|
||||
slot_mapping=slot_mapping,
|
||||
multi_modal_placeholder_index_maps=None,
|
||||
enable_kv_scales_calculation=False,
|
||||
block_tables=block_tables,
|
||||
context_lens=context_lens,
|
||||
)
|
||||
t = torch.ones((batch_size, ), dtype=torch.float32, device=self.device)
|
||||
p = torch.ones((batch_size, ), dtype=torch.float32, device=self.device)
|
||||
num_samples = _MAX_NUM_SAMPLES if exec_mode.is_prefill() else 1
|
||||
|
||||
# NOTE(woosuk): There are two stages of compilation: torch.compile and
|
||||
# XLA compilation. Using `mark_dynamic` can reduce the torch.compile
|
||||
# overhead by reusing the FX graph for different shapes.
|
||||
# However, the XLA graph will still require static shapes and needs to
|
||||
# be re-compiled for every different shapes. This overhead is inevitable
|
||||
# in the first run, but can be skipped afterwards as we cache the XLA
|
||||
# graphs on disk (VLLM_XLA_CACHE_PATH).
|
||||
if exec_mode.is_prefill():
|
||||
# Prefill
|
||||
torch._dynamo.mark_dynamic(token_ids, 1)
|
||||
torch._dynamo.mark_dynamic(position_ids, 1)
|
||||
torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 1)
|
||||
else:
|
||||
# Decode
|
||||
torch._dynamo.mark_dynamic(token_ids, 0)
|
||||
torch._dynamo.mark_dynamic(position_ids, 0)
|
||||
torch._dynamo.mark_dynamic(input_lens, 0)
|
||||
torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0)
|
||||
torch._dynamo.mark_dynamic(attn_metadata.context_lens, 0)
|
||||
torch._dynamo.mark_dynamic(attn_metadata.block_tables, 0)
|
||||
torch._dynamo.mark_dynamic(t, 0)
|
||||
torch._dynamo.mark_dynamic(p, 0)
|
||||
# Dummy run.
|
||||
with set_forward_context(attn_metadata, self.vllm_config, 0):
|
||||
self.model(token_ids, position_ids, input_lens, t, p, num_samples,
|
||||
kv_caches)
|
||||
|
||||
def warmup_model(
|
||||
self,
|
||||
kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||
) -> None:
|
||||
# Prefill
|
||||
logger.info("Compiling the model with different input shapes...")
|
||||
start = time.time()
|
||||
for batch_size in [1]:
|
||||
seq_len = 16
|
||||
while seq_len <= self.model_config.max_model_len:
|
||||
self._dummy_run(batch_size,
|
||||
seq_len,
|
||||
kv_caches,
|
||||
exec_mode=ExecutionMode.PREFILL)
|
||||
xm.wait_device_ops()
|
||||
logger.info("batch_size: %d, seq_len: %d", batch_size, seq_len)
|
||||
num_tokens = batch_size * seq_len
|
||||
if num_tokens >= self.scheduler_config.max_num_batched_tokens:
|
||||
break
|
||||
seq_len = seq_len * 2
|
||||
|
||||
end = time.time()
|
||||
logger.info("Compilation for prefill done in %.2f s.", end - start)
|
||||
|
||||
# Prefix prefill
|
||||
if self.cache_config.enable_prefix_caching:
|
||||
logger.info("Compiling the model with different input shapes for "
|
||||
"prefix prefill...")
|
||||
start = time.time()
|
||||
for batch_size in [1]:
|
||||
seq_len = 16
|
||||
while seq_len <= self.model_config.max_model_len:
|
||||
self._dummy_run(batch_size,
|
||||
seq_len,
|
||||
kv_caches,
|
||||
exec_mode=ExecutionMode.PREFIX_PREFILL)
|
||||
xm.wait_device_ops()
|
||||
logger.info("batch_size: %d, seq_len: %d", batch_size,
|
||||
seq_len)
|
||||
num_tokens = batch_size * seq_len
|
||||
if (num_tokens
|
||||
>= self.scheduler_config.max_num_batched_tokens):
|
||||
break
|
||||
seq_len = seq_len * 2
|
||||
end = time.time()
|
||||
logger.info("Compilation for prefix prefill done in %.2f s.",
|
||||
end - start)
|
||||
|
||||
# Decode
|
||||
start = time.time()
|
||||
seq_len = 1
|
||||
batch_size = 8 # Must be in sync with _get_padded_batch_size()
|
||||
while True:
|
||||
self._dummy_run(batch_size,
|
||||
seq_len,
|
||||
kv_caches,
|
||||
exec_mode=ExecutionMode.DECODE)
|
||||
xm.wait_device_ops()
|
||||
logger.info("batch_size: %d, seq_len: %d", batch_size, seq_len)
|
||||
|
||||
if batch_size >= self.scheduler_config.max_num_seqs:
|
||||
break
|
||||
batch_size = batch_size + 16 if batch_size >= 16 else batch_size * 2
|
||||
|
||||
end = time.time()
|
||||
logger.info("Compilation for decode done in %.2f s.", end - start)
|
||||
|
||||
def _prepare_prompt(
|
||||
self,
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor]:
|
||||
assert len(seq_group_metadata_list) > 0
|
||||
input_tokens: List[int] = []
|
||||
input_positions: List[int] = []
|
||||
prompt_lens: List[int] = []
|
||||
context_lens: List[int] = []
|
||||
slot_mapping: List[int] = []
|
||||
|
||||
for batch_idx, seq_group_metadata in enumerate(
|
||||
seq_group_metadata_list):
|
||||
assert seq_group_metadata.is_prompt
|
||||
seq_ids = list(seq_group_metadata.seq_data.keys())
|
||||
assert len(seq_ids) == 1
|
||||
seq_id = seq_ids[0]
|
||||
|
||||
seq_data = seq_group_metadata.seq_data[seq_id]
|
||||
# Could include output tokens when a request is preempted.
|
||||
prompt_tokens = seq_data.get_token_ids()
|
||||
seq_len = len(prompt_tokens)
|
||||
|
||||
num_computed_blocks = len(seq_group_metadata.computed_block_nums)
|
||||
num_computed_tokens = num_computed_blocks * self.block_size
|
||||
if num_computed_tokens > 0:
|
||||
prompt_tokens = prompt_tokens[num_computed_tokens:]
|
||||
context_lens.append(seq_len)
|
||||
else:
|
||||
context_lens.append(0)
|
||||
|
||||
prompt_len = len(prompt_tokens)
|
||||
prompt_lens.append(prompt_len)
|
||||
|
||||
input_tokens.extend(prompt_tokens)
|
||||
input_positions.extend(range(num_computed_tokens, seq_len))
|
||||
|
||||
assert seq_group_metadata.block_tables is not None
|
||||
block_table = seq_group_metadata.block_tables[seq_id]
|
||||
for i in range(num_computed_tokens, seq_len):
|
||||
block_number = block_table[i // self.block_size]
|
||||
block_offset = i % self.block_size
|
||||
slot = block_number * self.block_size + block_offset
|
||||
slot_mapping.append(slot)
|
||||
if num_computed_tokens > 0:
|
||||
self.block_tables[batch_idx, :len(block_table)] = block_table
|
||||
|
||||
# Add paddings to EACH prompt to the smallest power of 2 that is
|
||||
# greater than or equal to the prompt length.
|
||||
# We pad the seq_len to reduce the compilation overhead.
|
||||
# We execute each prompt individually (i.e., with batch_size 1)
|
||||
# because the FlashAttention kernel does not support ragged inputs.
|
||||
# TODO(woosuk): Use SplashAttention to support ragged inputs.
|
||||
padded_prompt_len = _get_padded_prefill_len(prompt_len)
|
||||
num_paddings = padded_prompt_len - prompt_len
|
||||
input_tokens += [0] * num_paddings
|
||||
input_positions += [0] * num_paddings
|
||||
slot_mapping += [_PAD_SLOT_ID] * num_paddings
|
||||
|
||||
assert len(prompt_lens) > 0
|
||||
num_prefills = len(prompt_lens)
|
||||
input_tokens = torch.tensor(input_tokens,
|
||||
dtype=torch.int32,
|
||||
device="cpu")
|
||||
input_positions = torch.tensor(input_positions,
|
||||
dtype=torch.int32,
|
||||
device="cpu")
|
||||
slot_mapping = torch.tensor(slot_mapping,
|
||||
dtype=torch.int64,
|
||||
device="cpu")
|
||||
prompt_lens = torch.tensor(prompt_lens,
|
||||
dtype=torch.int32,
|
||||
device="cpu")
|
||||
context_lens = torch.tensor(context_lens,
|
||||
dtype=torch.int32,
|
||||
device="cpu")
|
||||
block_tables = torch.tensor(self.block_tables[:num_prefills],
|
||||
dtype=torch.int32,
|
||||
device="cpu")
|
||||
attn_metadata = self.attn_backend.make_metadata(
|
||||
num_prefills=num_prefills,
|
||||
num_prefill_tokens=0, # NOTE: This is not used.
|
||||
num_decode_tokens=0,
|
||||
slot_mapping=slot_mapping,
|
||||
multi_modal_placeholder_index_maps=None,
|
||||
enable_kv_scales_calculation=False,
|
||||
block_tables=block_tables,
|
||||
context_lens=context_lens,
|
||||
effective_query_lens=prompt_lens,
|
||||
)
|
||||
return input_tokens, input_positions, attn_metadata, prompt_lens
|
||||
|
||||
def _prepare_decode(
|
||||
self,
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor]:
|
||||
assert len(seq_group_metadata_list) > 0
|
||||
input_tokens: List[List[int]] = []
|
||||
input_positions: List[List[int]] = []
|
||||
slot_mapping: List[List[int]] = []
|
||||
context_lens: List[int] = []
|
||||
|
||||
batch_idx = 0
|
||||
for seq_group_metadata in seq_group_metadata_list:
|
||||
assert not seq_group_metadata.is_prompt
|
||||
seq_ids = list(seq_group_metadata.seq_data.keys())
|
||||
for seq_id in seq_ids:
|
||||
seq_data = seq_group_metadata.seq_data[seq_id]
|
||||
generation_token = seq_data.get_last_token_id()
|
||||
input_tokens.append([generation_token])
|
||||
|
||||
seq_len = seq_data.get_len()
|
||||
position = seq_len - 1
|
||||
input_positions.append([position])
|
||||
context_lens.append(seq_len)
|
||||
|
||||
assert seq_group_metadata.block_tables is not None
|
||||
block_table = seq_group_metadata.block_tables[seq_id]
|
||||
self.block_tables[batch_idx, :len(block_table)] = block_table
|
||||
batch_idx += 1
|
||||
|
||||
block_number = block_table[position // self.block_size]
|
||||
block_offset = position % self.block_size
|
||||
slot = block_number * self.block_size + block_offset
|
||||
slot_mapping.append([slot])
|
||||
|
||||
batch_size = _get_padded_batch_size(batch_idx)
|
||||
num_paddings = batch_size - batch_idx
|
||||
input_tokens = input_tokens + [[0]] * num_paddings
|
||||
input_positions = input_positions + [[0]] * num_paddings
|
||||
slot_mapping = slot_mapping + [[_PAD_SLOT_ID]] * num_paddings
|
||||
context_lens = context_lens + [0] * num_paddings
|
||||
|
||||
input_tokens = torch.tensor(input_tokens,
|
||||
dtype=torch.int32,
|
||||
device="cpu")
|
||||
input_positions = torch.tensor(input_positions,
|
||||
dtype=torch.int32,
|
||||
device="cpu")
|
||||
slot_mapping = torch.tensor(slot_mapping,
|
||||
dtype=torch.int64,
|
||||
device="cpu")
|
||||
context_lens = torch.tensor(context_lens,
|
||||
dtype=torch.int32,
|
||||
device="cpu")
|
||||
block_tables = torch.tensor(self.block_tables[:batch_size],
|
||||
dtype=torch.int32,
|
||||
device="cpu")
|
||||
input_lens = torch.tensor([1] * batch_size,
|
||||
dtype=torch.int32,
|
||||
device="cpu")
|
||||
attn_metadata = self.attn_backend.make_metadata(
|
||||
num_prefills=0,
|
||||
num_prefill_tokens=0,
|
||||
num_decode_tokens=batch_size,
|
||||
slot_mapping=slot_mapping,
|
||||
multi_modal_placeholder_index_maps=None,
|
||||
enable_kv_scales_calculation=False,
|
||||
block_tables=block_tables,
|
||||
context_lens=context_lens,
|
||||
)
|
||||
return input_tokens, input_positions, attn_metadata, input_lens
|
||||
|
||||
def _prepare_sample(
|
||||
self,
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
padded_batch_size: int,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, List[int]]:
|
||||
assert len(seq_group_metadata_list) > 0
|
||||
t = []
|
||||
p = []
|
||||
n = []
|
||||
for seq_group_metadata in seq_group_metadata_list:
|
||||
sampling_params = seq_group_metadata.sampling_params
|
||||
t.append(sampling_params.temperature)
|
||||
if sampling_params.top_p != 1 and not _ENABLE_TOP_P:
|
||||
raise NotImplementedError(
|
||||
"Top-p sampling is currently disabled for the TPU backend "
|
||||
"due to performance issues.")
|
||||
p.append(sampling_params.top_p)
|
||||
if sampling_params.top_k > 0:
|
||||
raise NotImplementedError(
|
||||
"Top-k sampling is currently disabled for the TPU backend "
|
||||
"due to performance issues.")
|
||||
if sampling_params.n > _MAX_NUM_SAMPLES:
|
||||
raise NotImplementedError(
|
||||
f"Best of > {_MAX_NUM_SAMPLES} is not supported by the TPU "
|
||||
"backend.")
|
||||
n.append(sampling_params.n)
|
||||
if sampling_params.logprobs is not None:
|
||||
raise NotImplementedError(
|
||||
"logprobs is not currently supported by the TPU backend.")
|
||||
if sampling_params.prompt_logprobs is not None:
|
||||
raise NotImplementedError(
|
||||
"prompt_logprobs is not currently supported by the TPU "
|
||||
"backend.")
|
||||
|
||||
# Repeat the sampling params if the seq group has multiple seqs.
|
||||
num_seqs = len(seq_group_metadata.seq_data)
|
||||
t += [t[-1]] * (num_seqs - 1)
|
||||
p += [p[-1]] * (num_seqs - 1)
|
||||
n += [n[-1]] * (num_seqs - 1)
|
||||
|
||||
num_paddings = padded_batch_size - len(t)
|
||||
t += [1.0] * num_paddings
|
||||
p += [1.0] * num_paddings
|
||||
|
||||
t = torch.tensor(t, dtype=torch.float32, device="cpu")
|
||||
p = torch.tensor(p, dtype=torch.float32, device="cpu")
|
||||
return t, p, n
|
||||
|
||||
def prepare_model_input(
|
||||
self,
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
virtual_engine: int = 0,
|
||||
finished_requests_ids: Optional[List[str]] = None,
|
||||
) -> ModelInputForTPU:
|
||||
del finished_requests_ids # Unused.
|
||||
assert virtual_engine == 0
|
||||
assert len(seq_group_metadata_list) > 0
|
||||
# NOTE: We assume that all sequences in the group are all prompts or
|
||||
# all decodes.
|
||||
is_prompt = seq_group_metadata_list[0].is_prompt
|
||||
if is_prompt:
|
||||
inputs = self._prepare_prompt(seq_group_metadata_list)
|
||||
else:
|
||||
inputs = self._prepare_decode(seq_group_metadata_list)
|
||||
input_tokens, input_positions, attn_metadata, input_lens = inputs
|
||||
padded_batch_size = input_tokens.shape[0]
|
||||
t, p, n = self._prepare_sample(seq_group_metadata_list,
|
||||
padded_batch_size)
|
||||
num_samples = _MAX_NUM_SAMPLES if is_prompt else 1
|
||||
|
||||
seq_groups = [
|
||||
list(metadata.seq_data.keys())
|
||||
for metadata in seq_group_metadata_list
|
||||
]
|
||||
return ModelInputForTPU(input_tokens, input_positions, attn_metadata,
|
||||
input_lens, t, p, num_samples, n, seq_groups)
|
||||
|
||||
def make_model_input_from_broadcasted_tensor_dict(
|
||||
self, tensor_dict: Dict[str, Any]) -> ModelInputForTPU:
|
||||
model_input = ModelInputForTPU.from_broadcasted_tensor_dict(
|
||||
tensor_dict, attn_backend=self.attn_backend)
|
||||
return model_input
|
||||
|
||||
@torch.no_grad()
|
||||
def execute_model(
|
||||
self,
|
||||
model_input: ModelInputForTPU,
|
||||
kv_caches: Optional[List[Any]],
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
num_steps: int = 1,
|
||||
) -> List[SamplerOutput]:
|
||||
assert intermediate_tensors is None
|
||||
if not model_input.is_first_multi_step:
|
||||
if not model_input.is_last_step:
|
||||
return []
|
||||
|
||||
use_async_out_proc = model_input.async_callback is not None
|
||||
sampler_outputs = []
|
||||
num_outputs = len(self.cached_step_outputs)
|
||||
for i in range(num_outputs):
|
||||
next_token_ids = self.cached_step_outputs.pop(0)
|
||||
next_token_ids = next_token_ids.cpu().tolist()
|
||||
sampler_output = _make_decode_output(next_token_ids,
|
||||
model_input.seq_groups)
|
||||
sampler_outputs.append(sampler_output)
|
||||
|
||||
if i < num_outputs - 1 and use_async_out_proc:
|
||||
assert model_input.async_callback is not None
|
||||
ctx = model_input.async_callback.keywords[ # type: ignore
|
||||
"ctx"]
|
||||
ctx.append_output(
|
||||
outputs=[sampler_output],
|
||||
seq_group_metadata_list=ctx.seq_group_metadata_list,
|
||||
scheduler_outputs=ctx.scheduler_outputs,
|
||||
is_async=False,
|
||||
is_last_step=False,
|
||||
is_first_step_output=i == 0)
|
||||
model_input.async_callback()
|
||||
if use_async_out_proc:
|
||||
return [sampler_outputs[-1]]
|
||||
else:
|
||||
return sampler_outputs
|
||||
|
||||
is_prompt = model_input.attn_metadata.num_prefills > 0
|
||||
if is_prompt:
|
||||
assert num_steps == 1
|
||||
# NOTE(woosuk): Since the FlashAttention kernel does not support
|
||||
# ragged inputs, we split the prompts into different batches and
|
||||
# process them separately. This is a temporary hack that should be
|
||||
# optimized by using SplashAttention.
|
||||
orig_slot_mapping = model_input.attn_metadata.slot_mapping
|
||||
orig_block_tables = model_input.attn_metadata.block_tables
|
||||
orig_context_lens = model_input.attn_metadata.context_lens
|
||||
orig_effective_query_lens = \
|
||||
model_input.attn_metadata.effective_query_lens
|
||||
batch_size = model_input.input_lens.shape[0]
|
||||
start_idx = 0
|
||||
next_token_ids = []
|
||||
for i in range(batch_size):
|
||||
# Get the actual prefill_len.
|
||||
prefill_len = model_input.input_lens[i:i + 1].item()
|
||||
prefill_len = _get_padded_prefill_len(prefill_len)
|
||||
end_idx = start_idx + prefill_len
|
||||
|
||||
token_ids = model_input.token_ids[None, start_idx:end_idx].to(
|
||||
self.device)
|
||||
position_ids = model_input.position_ids[None,
|
||||
start_idx:end_idx].to(
|
||||
self.device)
|
||||
attn_metadata = model_input.attn_metadata
|
||||
attn_metadata.num_prefills = 1
|
||||
attn_metadata.slot_mapping = orig_slot_mapping[
|
||||
None, start_idx:end_idx].to(self.device)
|
||||
if orig_context_lens[i].item() > 0:
|
||||
attn_metadata.context_lens = orig_context_lens[i:i + 1].to(
|
||||
self.device)
|
||||
attn_metadata.block_tables = orig_block_tables[
|
||||
i].unsqueeze(0).to(self.device)
|
||||
attn_metadata.effective_query_lens = \
|
||||
orig_effective_query_lens[i:i + 1].to(self.device)
|
||||
else:
|
||||
attn_metadata.context_lens = None
|
||||
attn_metadata.block_tables = None
|
||||
attn_metadata.effective_query_lens = None
|
||||
input_lens = model_input.input_lens[i:i + 1].to(self.device)
|
||||
t = model_input.t[i:i + 1].to(self.device)
|
||||
p = model_input.p[i:i + 1].to(self.device)
|
||||
with set_forward_context(model_input.attn_metadata,
|
||||
self.vllm_config,
|
||||
model_input.virtual_engine):
|
||||
output_token_ids = self.model(token_ids, position_ids,
|
||||
input_lens, t, p,
|
||||
model_input.num_samples,
|
||||
kv_caches)
|
||||
next_token_ids.append(output_token_ids[0])
|
||||
start_idx = end_idx
|
||||
|
||||
if model_input.async_callback is not None:
|
||||
model_input.async_callback()
|
||||
# Retrieve the outputs to CPU.
|
||||
next_token_ids = [
|
||||
output_token_ids.cpu().tolist()
|
||||
for output_token_ids in next_token_ids
|
||||
]
|
||||
|
||||
# NOTE(woosuk): Minimal code to construct the sampler outputs.
|
||||
# The TPU backend does not reuse the sampler, since the TPU backend
|
||||
# does not support advanced sampling parameters such as logprobs.
|
||||
zero_logprob = Logprob(0.0)
|
||||
sampler_outputs = []
|
||||
for i, seq_group in enumerate(model_input.seq_groups):
|
||||
seq_ids = seq_group
|
||||
assert len(seq_ids) == 1
|
||||
seq_id = seq_ids[0]
|
||||
seq_outputs = []
|
||||
for j in range(model_input.n[i]):
|
||||
next_token_id = next_token_ids[i][j]
|
||||
seq_outputs.append(
|
||||
SequenceOutput(seq_id, next_token_id,
|
||||
{next_token_id: zero_logprob}))
|
||||
sampler_outputs.append(
|
||||
CompletionSequenceGroupOutput(seq_outputs, None))
|
||||
return [SamplerOutput(sampler_outputs)]
|
||||
else:
|
||||
token_ids = model_input.token_ids.to(self.device)
|
||||
position_ids = model_input.position_ids.to(self.device)
|
||||
attn_metadata = model_input.attn_metadata
|
||||
attn_metadata.slot_mapping = attn_metadata.slot_mapping.to(
|
||||
self.device)
|
||||
attn_metadata.block_tables = attn_metadata.block_tables.to(
|
||||
self.device)
|
||||
attn_metadata.context_lens = attn_metadata.context_lens.to(
|
||||
self.device)
|
||||
t = model_input.t.to(self.device)
|
||||
p = model_input.p.to(self.device)
|
||||
input_lens = model_input.input_lens.to(self.device)
|
||||
for i in range(num_steps):
|
||||
slot_mapping = attn_metadata.slot_mapping
|
||||
with set_forward_context(model_input.attn_metadata,
|
||||
self.vllm_config,
|
||||
model_input.virtual_engine):
|
||||
output_token_ids = self.model(token_ids, position_ids,
|
||||
input_lens, t, p,
|
||||
model_input.num_samples,
|
||||
kv_caches)
|
||||
self.cached_step_outputs.append(output_token_ids)
|
||||
|
||||
if i < num_steps - 1:
|
||||
# Prepare the inputs for the next step.
|
||||
token_ids = output_token_ids.unsqueeze(dim=1).int()
|
||||
position_ids = position_ids + 1
|
||||
attn_metadata.context_lens = attn_metadata.context_lens + 1
|
||||
|
||||
block_tables = attn_metadata.block_tables
|
||||
block_number = block_tables.gather(
|
||||
1,
|
||||
position_ids.long() // self.block_size)
|
||||
block_offset = position_ids % self.block_size
|
||||
|
||||
is_padding = slot_mapping == _PAD_SLOT_ID
|
||||
slot_mapping = block_number * self.block_size + block_offset
|
||||
slot_mapping = slot_mapping.long()
|
||||
slot_mapping = torch.where(is_padding, _PAD_SLOT_ID,
|
||||
slot_mapping)
|
||||
attn_metadata.slot_mapping = slot_mapping
|
||||
|
||||
if model_input.async_callback is not None:
|
||||
model_input.async_callback()
|
||||
|
||||
if num_steps > 1:
|
||||
return []
|
||||
# Retrieve the outputs to CPU.
|
||||
next_token_ids = self.cached_step_outputs.pop(0)
|
||||
next_token_ids = next_token_ids.cpu().tolist()
|
||||
sampler_output = _make_decode_output(next_token_ids,
|
||||
model_input.seq_groups)
|
||||
return [sampler_output]
|
||||
|
||||
|
||||
class ModelWrapper(nn.Module):
|
||||
|
||||
def __init__(self, model: nn.Module):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
|
||||
def forward(
|
||||
self,
|
||||
token_ids: torch.Tensor,
|
||||
position_ids: torch.Tensor,
|
||||
input_lens: torch.Tensor,
|
||||
t: torch.Tensor,
|
||||
p: torch.Tensor,
|
||||
num_samples: int,
|
||||
kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||
) -> torch.Tensor:
|
||||
"""Executes the forward pass of the model and samples the next token.
|
||||
|
||||
Args:
|
||||
token_ids: The input token IDs of shape [batch_size, seq_len].
|
||||
position_ids: The input position IDs of shape [batch_size, seq_len].
|
||||
input_lens: The actual input lengths of shape [batch_size].
|
||||
t: The sampling temperature of shape [batch_size].
|
||||
p: The top-p probability of shape [batch_size].
|
||||
num_samples: Number of samples to draw from each logits vector.
|
||||
kv_caches: The key and value caches. They can be None during the
|
||||
memory profiling at initialization.
|
||||
"""
|
||||
batch_size, seq_len = token_ids.shape
|
||||
# Calculate the positions to sample from.
|
||||
start_indices = torch.arange(
|
||||
batch_size, dtype=torch.int32, device=input_lens.device) * seq_len
|
||||
logits_indices = start_indices + input_lens - 1
|
||||
attn_metadata = get_forward_context().attn_metadata
|
||||
|
||||
# FIXME(woosuk): This is a temporary hack to avoid using the existing
|
||||
# sampler and sampling metadata.
|
||||
sampling_metadata = SamplingMetadata(
|
||||
seq_groups=[],
|
||||
selected_token_indices=logits_indices,
|
||||
categorized_sample_indices={},
|
||||
num_prompts=attn_metadata.num_prefills,
|
||||
)
|
||||
|
||||
# Skip this in memory profiling at initialization.
|
||||
if kv_caches[0][0].numel() > 0:
|
||||
# index_copy_(slot_mapping) only works when the inserted dimension
|
||||
# is 0. However, the KV cache in the Pallas backend has the shape
|
||||
# [num_kv_heads, num_blocks, block_size, head_size]. To make it
|
||||
# work, we need to flatten the first three dimensions and modify
|
||||
# the slot_mapping accordingly.
|
||||
num_kv_heads, num_blocks, block_size, _ = kv_caches[0][0].shape
|
||||
slot_mapping = attn_metadata.slot_mapping
|
||||
slot_mapping = slot_mapping.flatten()
|
||||
head_indices = torch.arange(0,
|
||||
num_kv_heads,
|
||||
device=slot_mapping.device,
|
||||
dtype=slot_mapping.dtype)
|
||||
head_indices *= block_size * num_blocks
|
||||
slot_mapping = slot_mapping.repeat_interleave(num_kv_heads).view(
|
||||
-1, num_kv_heads)
|
||||
slot_mapping = slot_mapping + head_indices.view(1, -1)
|
||||
slot_mapping = slot_mapping.flatten()
|
||||
attn_metadata.slot_mapping = slot_mapping
|
||||
|
||||
hidden_states = self.model(token_ids, position_ids)
|
||||
hidden_states = hidden_states.flatten(0, 1)
|
||||
logits = self.model.compute_logits(hidden_states, sampling_metadata)
|
||||
|
||||
# Argmax sampling.
|
||||
argmax_token_ids = torch.argmax(logits, dim=-1, keepdim=True)
|
||||
argmax_token_ids = argmax_token_ids.repeat(1, num_samples)
|
||||
|
||||
# Zero temperature means greedy decoding. Avoid division by zero.
|
||||
nonzero_t = torch.where(t != 0, t, 1.0)
|
||||
logits = logits / nonzero_t.unsqueeze(dim=1)
|
||||
if _ENABLE_TOP_P:
|
||||
logits = _apply_top_p(logits, p.unsqueeze(dim=1))
|
||||
|
||||
# Random sampling.
|
||||
probs = torch.softmax(logits, dim=-1, dtype=torch.float32)
|
||||
sampled_token_ids = torch.multinomial(probs,
|
||||
num_samples,
|
||||
replacement=True)
|
||||
if num_samples == 1:
|
||||
argmax_token_ids = argmax_token_ids.squeeze(dim=-1)
|
||||
sampled_token_ids = sampled_token_ids.squeeze(dim=-1)
|
||||
next_token_ids = torch.where(t != 0, sampled_token_ids,
|
||||
argmax_token_ids)
|
||||
return next_token_ids
|
||||
|
||||
|
||||
def _get_padded_prefill_len(x: int) -> int:
|
||||
# NOTE(woosuk): The pallas FlashAttention kernel requires the sequence
|
||||
# length to be a multiple of 16. We pad the prompt length to the nearest
|
||||
# multiple of 16. This is also good for performance.
|
||||
if x <= 16:
|
||||
return 16
|
||||
return 1 << (x - 1).bit_length()
|
||||
|
||||
|
||||
def _get_padded_batch_size(batch_size: int) -> int:
|
||||
# The GMM Pallas kernel requires num_tokens * topk to be a multiple of 16.
|
||||
# To meet this requirement in the simplest way, we set the minimal batch
|
||||
# size to 8.
|
||||
if batch_size <= 8:
|
||||
return 8
|
||||
else:
|
||||
return ((batch_size + 15) // 16) * 16
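# Illustrative checks (not part of the file above) of the two padding
# helpers: prompt lengths are padded up to the next power of two (with a
# minimum of 16), and decode batch sizes are padded to 8 or to the next
# multiple of 16.
assert _get_padded_prefill_len(1) == 16
assert _get_padded_prefill_len(16) == 16
assert _get_padded_prefill_len(17) == 32
assert _get_padded_prefill_len(100) == 128
assert _get_padded_batch_size(3) == 8
assert _get_padded_batch_size(9) == 16
assert _get_padded_batch_size(17) == 32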
|
||||
|
||||
|
||||
def _apply_top_p(logits: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
|
||||
logits_sorted = torch.sort(logits, dim=-1, descending=True).values
|
||||
sorted_cum_probs = torch.cumsum(logits_sorted.softmax(dim=-1), dim=-1)
|
||||
cutoff_index = torch.sum(sorted_cum_probs < p, dim=-1, keepdim=True)
|
||||
cutoff_logit = torch.gather(logits_sorted, -1, cutoff_index)
|
||||
logits = logits.masked_fill_(logits < cutoff_logit, -float("inf"))
|
||||
return logits
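# Illustrative check (not part of the file above) of `_apply_top_p`: with
# logits [2, 1, 0] (softmax roughly [0.66, 0.24, 0.09]) and p = 0.9, the
# smallest logit falls outside the nucleus and is masked to -inf while the
# top two survive. The tensors below are hypothetical example values.
_example_logits = torch.tensor([[2.0, 1.0, 0.0]])
_masked = _apply_top_p(_example_logits.clone(), torch.tensor([[0.9]]))
assert torch.isinf(_masked[0, 2]).item()
assert not torch.isinf(_masked[0, 1]).item()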
|
||||
|
||||
|
||||
def _make_decode_output(
|
||||
next_token_ids: List[int],
|
||||
seq_groups: List[List[int]],
|
||||
) -> SamplerOutput:
|
||||
zero_logprob = Logprob(0.0)
|
||||
sampler_outputs = []
|
||||
batch_idx = 0
|
||||
for seq_group in seq_groups:
|
||||
seq_ids = seq_group
|
||||
seq_outputs = []
|
||||
for seq_id in seq_ids:
|
||||
next_token_id = next_token_ids[batch_idx]
|
||||
seq_outputs.append(
|
||||
SequenceOutput(seq_id, next_token_id,
|
||||
{next_token_id: zero_logprob}))
|
||||
batch_idx += 1
|
||||
sampler_outputs.append(CompletionSequenceGroupOutput(
|
||||
seq_outputs, None))
|
||||
return SamplerOutput(sampler_outputs)
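# Illustrative, standalone sketch (not part of the file above) of the
# slot-mapping rule used in `_prepare_prompt` and `_prepare_decode` above:
# token position i maps to
# block_table[i // block_size] * block_size + i % block_size.
# `block_table`, `block_size`, and `num_tokens` are hypothetical inputs.
def _sketch_slot_mapping(block_table, block_size, num_tokens):
    return [
        block_table[i // block_size] * block_size + i % block_size
        for i in range(num_tokens)
    ]


# With block_size=16 and physical blocks [7, 2], token 0 lands in slot 112
# (= 7 * 16) and token 16 lands in slot 32 (= 2 * 16).
assert _sketch_slot_mapping([7, 2], 16, 17)[0] == 112
assert _sketch_slot_mapping([7, 2], 16, 17)[16] == 32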
|
|
@ -0,0 +1,337 @@
|
|||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import os
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch_xla.core.xla_model as xm
|
||||
import torch_xla.debug.profiler as xp
|
||||
import torch_xla.runtime as xr
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import (ensure_model_parallel_initialized,
|
||||
init_distributed_environment)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor import set_random_seed
|
||||
from vllm.sequence import ExecuteModelRequest
|
||||
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, bind_kv_cache, get_dtype_size
|
||||
from vllm.worker.tpu_model_runner import ExecutionMode, TPUModelRunner
|
||||
from vllm.worker.worker_base import (LocalOrDistributedWorkerBase,
|
||||
LoRANotSupportedWorkerBase, WorkerBase,
|
||||
WorkerInput)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class TPUWorker(LoRANotSupportedWorkerBase, LocalOrDistributedWorkerBase):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
local_rank: int,
|
||||
rank: int,
|
||||
distributed_init_method: str,
|
||||
is_driver_worker: bool,
|
||||
) -> None:
|
||||
WorkerBase.__init__(self, vllm_config=vllm_config)
|
||||
self.parallel_config.rank = rank
|
||||
self.local_rank = local_rank
|
||||
self.rank = rank
|
||||
self.distributed_init_method = distributed_init_method
|
||||
self.is_driver_worker = is_driver_worker
|
||||
|
||||
assert self.device_config.device_type == "tpu"
|
||||
if self.cache_config.cache_dtype == "auto":
|
||||
self.cache_dtype = self.model_config.dtype
|
||||
else:
|
||||
self.cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
|
||||
self.cache_config.cache_dtype]
|
||||
|
||||
self.model_runner: TPUModelRunner = TPUModelRunner(
|
||||
vllm_config=vllm_config, is_driver_worker=is_driver_worker)
|
||||
|
||||
if self.model_config.seed is None:
|
||||
self.model_config.seed = 0
|
||||
|
||||
if vllm_config.lora_config is not None:
|
||||
raise NotImplementedError(
|
||||
"The V0 TPU backend doesn't support LoRA serving")
|
||||
|
||||
def init_device(self) -> None:
|
||||
os.environ["PJRT_DEVICE"] = "TPU"
|
||||
torch.set_grad_enabled(False)
|
||||
torch.set_default_dtype(self.model_config.dtype)
|
||||
|
||||
# NOTE(woosuk): This is just to initialize the TP group and broadcast
|
||||
# the input objects on CPU. The all-reduce and all-gather ops on TPU
|
||||
# are invoked by `xm.all_reduce` and `xm.all_gather` which use their
|
||||
# own context.
|
||||
init_distributed_environment(
|
||||
world_size=self.parallel_config.world_size,
|
||||
rank=self.rank,
|
||||
local_rank=self.local_rank,
|
||||
distributed_init_method=self.distributed_init_method,
|
||||
backend="gloo",
|
||||
)
|
||||
ensure_model_parallel_initialized(
|
||||
self.parallel_config.tensor_parallel_size,
|
||||
self.parallel_config.pipeline_parallel_size)
|
||||
|
||||
# Device initialization should happen after initializing the distributed
|
||||
# runtime.
|
||||
self.device = xm.xla_device()
|
||||
self.device_config.device = self.device
|
||||
|
||||
# Set random seed.
|
||||
set_random_seed(self.model_config.seed)
|
||||
xm.set_rng_state(self.model_config.seed, self.device)
|
||||
|
||||
# Increase the cache size limit, which is the maximum number of
|
||||
# dynamo graphs that can be compiled.
|
||||
# NOTE(woosuk): Usually, we compile 10-15 graphs for prefill and
|
||||
# 30-40 graphs for decode. 128 is an arbitrary safe number.
|
||||
torch._dynamo.config.cache_size_limit = 128
|
||||
# Use persistent cache to avoid XLA recompilation.
|
||||
# NOTE(woosuk): Set per-rank cache path since different ranks
|
||||
# can have slightly different XLA graphs.
|
||||
world_size = self.parallel_config.world_size
|
||||
rank = xr.global_ordinal()
|
||||
# The PyTorch/XLA compilation cache uses the Torch IR to generate keys.
|
||||
# Consequently, changes in optimization flags, which affect compilation
|
||||
# results, don't change the cache key. This can result in the wrong
|
||||
# compilation being used. To prevent this, disabling the XLA compilation
|
||||
# cache during development is recommended. We can disable it by
|
||||
# `export VLLM_XLA_CACHE_PATH=`
|
||||
if envs.VLLM_XLA_CACHE_PATH:
|
||||
per_rank_path = os.path.join(envs.VLLM_XLA_CACHE_PATH,
|
||||
f"tp{world_size}_rank{rank}")
|
||||
xr.initialize_cache(per_rank_path, readonly=False)
|
||||
|
||||
self.profiler = None
|
||||
if envs.VLLM_TORCH_PROFILER_DIR and self.rank < 1:
|
||||
# For TPU, we can only have 1 active profiler session for 1 profiler
|
||||
# server. So we only profile on rank0.
|
||||
self.profile_dir = envs.VLLM_TORCH_PROFILER_DIR
|
||||
logger.info("Profiling enabled. Traces will be saved to: %s",
|
||||
self.profile_dir)
|
||||
self.profiler = xp.start_server(9012)
|
||||
|
||||
def start_profile(self):
|
||||
if self.rank < 1:
|
||||
if self.profiler is None:
|
||||
raise RuntimeError("Profiler is not enabled.")
|
||||
xp.start_trace(self.profile_dir)
|
||||
|
||||
def stop_profile(self):
|
||||
if self.rank < 1:
|
||||
if self.profiler is None:
|
||||
raise RuntimeError("Profiler is not enabled.")
|
||||
xp.stop_trace()
|
||||
|
||||
def load_model(self):
|
||||
self.model_runner.load_model()
|
||||
|
||||
def determine_num_available_blocks(self) -> Tuple[int, int]:
|
||||
num_layers = self.model_config.get_num_layers(self.parallel_config)
|
||||
head_size = self.model_config.get_head_size()
|
||||
num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)
|
||||
|
||||
# Use an empty tensor instead of `None` to force Dynamo to pass
# it by reference, rather than specializing on the value `None`.
|
||||
# the `dtype` argument does not matter, and we use `float32` as
|
||||
# a placeholder (it has wide hardware support).
|
||||
kv_caches = [(torch.tensor([], dtype=torch.float32,
|
||||
device=self.device),
|
||||
torch.tensor([], dtype=torch.float32,
|
||||
device=self.device))
|
||||
for _ in range(num_layers)]
|
||||
bind_kv_cache(self.compilation_config.static_forward_context,
|
||||
[kv_caches])
|
||||
self.model_runner._dummy_run(
|
||||
batch_size=1,
|
||||
seq_len=self.scheduler_config.max_num_batched_tokens,
|
||||
kv_caches=kv_caches,
|
||||
exec_mode=ExecutionMode.PREFILL,
|
||||
)
|
||||
# Synchronize before measuring the memory usage.
|
||||
xm.wait_device_ops()
|
||||
|
||||
# Get the maximum amount of memory used by the model weights and
|
||||
# intermediate activations.
|
||||
m = xm.get_memory_info(self.device)
|
||||
total_memory_size = m["bytes_limit"]
|
||||
profiled = m["peak_bytes_used"] # Weights + intermediate activations.
|
||||
|
||||
# Calculate the TPU KV cache size based on profiling.
|
||||
usable_memory_size = int(total_memory_size *
|
||||
self.cache_config.gpu_memory_utilization)
|
||||
tpu_kv_cache_bytes = max(usable_memory_size - profiled, 0)
|
||||
dtype_bytes = get_dtype_size(self.cache_dtype)
|
||||
block_size_bytes = (dtype_bytes * self.cache_config.block_size *
|
||||
num_layers * 2 * head_size * num_kv_heads)
|
||||
num_tpu_blocks = tpu_kv_cache_bytes // block_size_bytes
|
||||
num_tpu_blocks = (num_tpu_blocks // 8) * 8 # Round down to 8.
|
||||
|
||||
# Calculate the CPU KV cache size based on the config.
|
||||
num_cpu_blocks = int(self.cache_config.swap_space_bytes //
|
||||
block_size_bytes)
|
||||
num_cpu_blocks = (num_cpu_blocks // 8) * 8 # Round down to 8.
|
||||
return num_tpu_blocks, num_cpu_blocks
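# Worked example (illustrative, with hypothetical shapes): with bfloat16
# (2 bytes), block_size=16, 32 layers, head_size=128, and 8 KV heads, one
# block takes 2 * 16 * 32 * 2 * 128 * 8 bytes = 2 MiB, so 8 GiB of free TPU
# memory yields 4096 blocks (already a multiple of 8).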
|
||||
|
||||
def initialize_cache(
|
||||
self,
|
||||
num_gpu_blocks: int,
|
||||
num_cpu_blocks: int,
|
||||
) -> None:
|
||||
self.cache_config.num_gpu_blocks = num_gpu_blocks
|
||||
self.cache_config.num_cpu_blocks = num_cpu_blocks
|
||||
self.block_size = self.cache_config.block_size
|
||||
|
||||
dtype = self.cache_dtype
|
||||
num_layers = self.model_config.get_num_layers(self.parallel_config)
|
||||
num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)
|
||||
head_size = self.model_config.get_head_size()
|
||||
|
||||
self.cpu_cache: List[Tuple[torch.Tensor, torch.Tensor]] = []
|
||||
self.tpu_cache: List[Tuple[torch.Tensor, torch.Tensor]] = []
|
||||
tpu_cache_shape = self.model_runner.attn_backend.get_kv_cache_shape(
|
||||
num_gpu_blocks, self.block_size, num_kv_heads, head_size)
|
||||
cpu_cache_shape = self.model_runner.attn_backend.get_kv_cache_shape(
|
||||
num_cpu_blocks, self.block_size, num_kv_heads, head_size)
|
||||
for _ in range(num_layers):
|
||||
tpu_k_cache = torch.zeros(tpu_cache_shape,
|
||||
dtype=dtype,
|
||||
device=self.device)
|
||||
tpu_v_cache = torch.zeros_like(tpu_k_cache)
|
||||
self.tpu_cache.append((tpu_k_cache, tpu_v_cache))
|
||||
cpu_k_cache = torch.zeros(cpu_cache_shape,
|
||||
dtype=dtype,
|
||||
device="cpu")
|
||||
cpu_v_cache = torch.zeros_like(cpu_k_cache)
|
||||
self.cpu_cache.append((cpu_k_cache, cpu_v_cache))
|
||||
bind_kv_cache(self.compilation_config.static_forward_context,
|
||||
[self.tpu_cache])
|
||||
self._warmup_model()
|
||||
|
||||
def _warmup_model(self) -> None:
|
||||
# FIXME(woosuk): Here we are abusing `enforce_eager` which is defined
|
||||
# for CUDA graphs. We should refactor this part.
|
||||
if not self.model_config.enforce_eager:
|
||||
# Warm up the model with all possible input shapes so that
|
||||
# compilation never happens during the actual execution.
|
||||
# This may take ~30 mins for the first run and ~20 mins for the
|
||||
# subsequent runs.
|
||||
# If `enforce_eager` is True, the ahead-of-time compilation is
|
||||
# skipped and the compilation happens during the actual execution,
|
||||
# which is bad for performance but useful for development.
|
||||
self.model_runner.warmup_model(self.tpu_cache)
|
||||
|
||||
def get_cache_block_size_bytes(self) -> int:
|
||||
head_size = self.model_config.get_head_size()
|
||||
num_heads = self.model_config.get_num_kv_heads(self.parallel_config)
|
||||
num_layers = self.model_config.get_num_layers(self.parallel_config)
|
||||
|
||||
key_cache_block = self.cache_config.block_size * num_heads * head_size
|
||||
value_cache_block = key_cache_block
|
||||
total = num_layers * (key_cache_block + value_cache_block)
|
||||
dtype_size = get_dtype_size(self.cache_dtype)
|
||||
return dtype_size * total
|
||||
|
||||
@property
|
||||
def do_metadata_broadcast(self) -> bool:
|
||||
return self.parallel_config.tensor_parallel_size > 1
|
||||
|
||||
@property
|
||||
def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
|
||||
# NOTE(woosuk): This assumes virtual_engine == 0, i.e., no pipeline
|
||||
# parallelism.
|
||||
return [self.tpu_cache]
|
||||
|
||||
def prepare_worker_input(
|
||||
self,
|
||||
execute_model_req: ExecuteModelRequest,
|
||||
) -> WorkerInput:
|
||||
virtual_engine = execute_model_req.virtual_engine
|
||||
num_seq_groups = len(execute_model_req.seq_group_metadata_list)
|
||||
blocks_to_swap_in = _make_src_to_dst(
|
||||
execute_model_req.blocks_to_swap_in, "cpu", self.device)
|
||||
blocks_to_swap_out = _make_src_to_dst(
|
||||
execute_model_req.blocks_to_swap_out, self.device, "cpu")
|
||||
blocks_to_copy = _make_src_to_dst(execute_model_req.blocks_to_copy,
|
||||
self.device, self.device)
|
||||
return WorkerInput(
|
||||
num_seq_groups=num_seq_groups,
|
||||
blocks_to_swap_in=blocks_to_swap_in,
|
||||
blocks_to_swap_out=blocks_to_swap_out,
|
||||
blocks_to_copy=blocks_to_copy,
|
||||
virtual_engine=virtual_engine,
|
||||
)
|
||||
|
||||
def execute_worker(self, worker_input: WorkerInput) -> None:
|
||||
virtual_engine = worker_input.virtual_engine
|
||||
assert virtual_engine == 0
|
||||
attn_backend = self.model_runner.attn_backend
|
||||
num_layers = self.model_config.get_num_layers(self.parallel_config)
|
||||
|
||||
# Issue cache operations.
|
||||
if worker_input.blocks_to_swap_in is not None:
|
||||
src_indices, dst_indices = worker_input.blocks_to_swap_in
|
||||
if src_indices.numel() > 0:
|
||||
# Swap from CPU to TPU.
|
||||
for i in range(num_layers):
|
||||
tpu_k_cache, tpu_v_cache = self.tpu_cache[i]
|
||||
cpu_k_cache, cpu_v_cache = self.cpu_cache[i]
|
||||
k = cpu_k_cache[:, src_indices].to(self.device)
|
||||
v = cpu_v_cache[:, src_indices].to(self.device)
|
||||
_insert_kv(k, v, dst_indices, tpu_k_cache, tpu_v_cache)
|
||||
|
||||
if worker_input.blocks_to_swap_out is not None:
|
||||
src_indices, dst_indices = worker_input.blocks_to_swap_out
|
||||
if src_indices.numel() > 0:
|
||||
# Swap from TPU to CPU.
|
||||
for i in range(num_layers):
|
||||
tpu_k_cache, tpu_v_cache = self.tpu_cache[i]
|
||||
cpu_k_cache, cpu_v_cache = self.cpu_cache[i]
|
||||
cpu_k_cache[:, dst_indices] = tpu_k_cache[:, src_indices]
|
||||
cpu_v_cache[:, dst_indices] = tpu_v_cache[:, src_indices]
|
||||
|
||||
if worker_input.blocks_to_copy is not None:
|
||||
src_indices, dst_indices = worker_input.blocks_to_copy
|
||||
if src_indices.numel() > 0:
|
||||
attn_backend.copy_blocks(self.tpu_cache,
|
||||
(src_indices, dst_indices))
|
||||
|
||||
|
||||
def _make_src_to_dst(
|
||||
mapping: List[Tuple[int, int]],
|
||||
src_device: Union[torch.device, str],
|
||||
dst_device: Union[torch.device, str],
|
||||
) -> Optional[Tuple[torch.Tensor, torch.Tensor]]:
|
||||
if not mapping:
|
||||
return None
|
||||
|
||||
src_indices = [i for i, _ in mapping]
|
||||
dst_indices = [i for _, i in mapping]
|
||||
src_indices = torch.tensor(src_indices,
|
||||
device=src_device,
|
||||
dtype=torch.int64)
|
||||
dst_indices = torch.tensor(dst_indices,
|
||||
device=dst_device,
|
||||
dtype=torch.int64)
|
||||
return src_indices, dst_indices
|
||||
|
||||
|
||||
@torch.compile(backend="openxla")
|
||||
def _insert_kv(
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
indices: torch.Tensor,
|
||||
tpu_k_cache: torch.Tensor,
|
||||
tpu_v_cache: torch.Tensor,
|
||||
) -> None:
|
||||
torch.ops.xla.dynamo_set_buffer_donor_(tpu_k_cache, True)
|
||||
torch.ops.xla.dynamo_set_buffer_donor_(tpu_v_cache, True)
|
||||
tpu_k_cache[:, indices] = k
|
||||
tpu_v_cache[:, indices] = v
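# Illustrative example (not part of the file above) of `_make_src_to_dst`:
# a swap mapping [(0, 3), (1, 5)] becomes a pair of index tensors, so in the
# swap-in path above cpu_cache[:, [0, 1]] is copied into tpu_cache[:, [3, 5]].
# Both tensors are placed on CPU here purely for the sake of the example.
_src, _dst = _make_src_to_dst([(0, 3), (1, 5)], "cpu", "cpu")
assert _src.tolist() == [0, 1]
assert _dst.tolist() == [3, 5]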
|
|
@ -0,0 +1,606 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import dataclasses
import time
import weakref
from collections import defaultdict
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple,
                    Type, TypeVar)

import torch
import torch.nn as nn

from vllm.attention import get_attn_backend
from vllm.config import VllmConfig
from vllm.distributed import get_pp_group
from vllm.forward_context import set_forward_context
from vllm.inputs import INPUT_REGISTRY, InputRegistry
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadataCache
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.model_loader import get_model
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
                             MultiModalKwargs, MultiModalPlaceholderMap,
                             MultiModalRegistry)
from vllm.sampling_params import SamplingParams
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
from vllm.utils import DeviceMemoryProfiler, GiB_bytes, make_tensor_with_pad
from vllm.worker.model_runner import AttentionMetadata, SamplingMetadata
from vllm.worker.model_runner_base import (
    ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
    _add_attn_metadata_broadcastable_dict,
    _add_sampling_metadata_broadcastable_dict,
    _init_attn_metadata_from_tensor_dict,
    _init_sampling_metadata_from_tensor_dict)

if TYPE_CHECKING:
    from vllm.attention.backends.abstract import AttentionBackend

logger = init_logger(__name__)

_PAD_SLOT_ID = -1

TModelInputForXPU = TypeVar('TModelInputForXPU', bound="ModelInputForXPU")


@dataclass(frozen=True)
class ModelInputForXPU(ModelRunnerInputBase):
    """
    Used by the XPUModelRunner.
    """
    input_tokens: Optional[torch.Tensor] = None
    input_positions: Optional[torch.Tensor] = None
    attn_metadata: Optional["AttentionMetadata"] = None
    multi_modal_kwargs: Optional[BatchedTensorInputs] = None
    virtual_engine: Optional[int] = None
    seq_lens: Optional[List[int]] = None
    query_lens: Optional[List[int]] = None
    async_callback: Optional[Callable] = None

    def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
        tensor_dict = {
            "input_tokens": self.input_tokens,
            "input_positions": self.input_positions,
        }
        _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)

        return tensor_dict

    @classmethod
    def from_broadcasted_tensor_dict(
        cls: Type[TModelInputForXPU],
        tensor_dict: Dict[str, Any],
        attn_backend: Optional["AttentionBackend"] = None,
    ) -> TModelInputForXPU:
        if attn_backend is not None:
            tensor_dict = _init_attn_metadata_from_tensor_dict(
                attn_backend, tensor_dict)
        return cls(**tensor_dict)


@dataclass(frozen=True)
class ModelInputForXPUWithSamplingMetadata(ModelInputForXPU):
    """
    Used by the ModelRunner.
    """
    sampling_metadata: Optional["SamplingMetadata"] = None

    def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
        tensor_dict = {
            "input_tokens": self.input_tokens,
            "input_positions": self.input_positions,
        }
        _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
        _add_sampling_metadata_broadcastable_dict(tensor_dict,
                                                  self.sampling_metadata)
        return tensor_dict

    @classmethod
    def from_broadcasted_tensor_dict(
        cls,
        tensor_dict: Dict[str, Any],
        attn_backend: Optional["AttentionBackend"] = None,
    ) -> "ModelInputForXPUWithSamplingMetadata":
        tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict)
        if attn_backend is not None:
            tensor_dict = _init_attn_metadata_from_tensor_dict(
                attn_backend, tensor_dict)
        return cls(**tensor_dict)


class ModelInputForXPUBuilder(ModelRunnerInputBuilderBase[ModelInputForXPU]):

    def __init__(self,
                 runner: "XPUModelRunner",
                 finished_requests_ids: Optional[List[str]] = None) -> None:
        super().__init__()
        self.runner = runner
        self.model_input_cls = self.runner._model_input_cls
        self.attn_backend = self.runner.attn_backend
        self.sliding_window = self.runner.sliding_window
        self.block_size = self.runner.block_size
        self.device = self.runner.device

    def prepare(self,
                finished_requests_ids: Optional[List[str]] = None) -> None:
        self.seq_group_metadata_list: List[SequenceGroupMetadata] = []

    def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
        self.seq_group_metadata_list.append(seq_group_metadata)

    def build(self) -> ModelInputForXPU:
        is_prompt = self.seq_group_metadata_list[0].is_prompt
        # Prepare input tensors.
        if is_prompt:
            (input_tokens, input_positions, attn_metadata, seq_lens,
             multi_modal_kwargs) = self._prepare_prompt(
                 self.seq_group_metadata_list)
        else:
            (input_tokens, input_positions,
             attn_metadata) = self._prepare_decode(
                 self.seq_group_metadata_list)
            seq_lens = None
            multi_modal_kwargs = None

        return self.model_input_cls(
            input_tokens=input_tokens,
            input_positions=input_positions,
            attn_metadata=attn_metadata,
            multi_modal_kwargs=multi_modal_kwargs,
            seq_lens=seq_lens,
            query_lens=seq_lens,
        )

    def _prepare_prompt(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int],
               BatchedTensorInputs]:
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[int] = []
        input_positions: List[int] = []
        slot_mapping: List[int] = []
        seq_lens: List[int] = []
        multi_modal_kwargs_list: List[MultiModalKwargs] = []
        multi_modal_placeholder_maps: Dict[
            str,
            MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)

        for seq_group_metadata in seq_group_metadata_list:
            assert seq_group_metadata.is_prompt
            seq_ids = list(seq_group_metadata.seq_data.keys())
            assert len(seq_ids) == 1
            seq_id = seq_ids[0]

            seq_data = seq_group_metadata.seq_data[seq_id]
            prompt_tokens = seq_data.get_token_ids()
            computed_len = seq_data.get_num_computed_tokens()
            seq_len = len(prompt_tokens)

            seq_lens.append(seq_len)  # Prompt token num
            input_tokens.extend(prompt_tokens)  # Token ids

            # Token position ids
            # NOTE(woosuk): Here we assume that the first token in the prompt
            # is always the first token in the sequence.
            positions_range = range(computed_len, seq_len)
            input_positions.extend(list(positions_range))

            if seq_group_metadata.multi_modal_data:
                # NOTE: mm_kwargs only includes the subset of multi-modal items
                # that intersect with the current prefill positions.
                mm_kwargs, placeholder_maps = MultiModalPlaceholderMap \
                    .from_seq_group(seq_group_metadata, positions_range)

                multi_modal_kwargs_list.append(mm_kwargs)

                for modality, placeholder_map in placeholder_maps.items():
                    multi_modal_placeholder_maps[modality].extend(
                        placeholder_map)

            if seq_group_metadata.block_tables is None:
                # During memory profiling, the block tables are not initialized
                # yet. In this case, we just use a dummy slot mapping.
                slot_mapping.extend([_PAD_SLOT_ID] * seq_len)
                continue

            # Compute the slot mapping.
            block_table = seq_group_metadata.block_tables[seq_id]
            # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
            # where start_idx is max(0, seq_len - sliding_window).
            # For example, if the prompt len is 10, sliding window is 8, and
            # block size is 4, the first two tokens are masked and the slot
            # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
            start_idx = 0
            if self.sliding_window is not None:
                start_idx = max(0, seq_len - self.sliding_window)

            for i in range(computed_len, seq_len):
                if i < start_idx:
                    slot_mapping.append(_PAD_SLOT_ID)
                    continue

                block_number = block_table[i //
                                           self.block_size]  # type: ignore
                block_offset = i % self.block_size  # type: ignore
                slot = block_number * self.block_size + block_offset
                slot_mapping.append(slot)

        num_prompt_tokens = len(input_tokens)

        input_tokens = torch.tensor(input_tokens,
                                    dtype=torch.long,
                                    device=self.device)  # type: ignore
        input_positions = torch.tensor(input_positions,
                                       dtype=torch.long,
                                       device=self.device)  # type: ignore
        slot_mapping = torch.tensor(slot_mapping,
                                    dtype=torch.long,
                                    device=self.device)  # type: ignore
        placeholder_index_maps = {
            modality: placeholder_map.index_map()
            for modality, placeholder_map in
            multi_modal_placeholder_maps.items()
        }

        max_seqlen = max(seq_lens)
        tmp = [0]
        tmp.extend(seq_lens)
        seqlen = torch.tensor(tmp)
        seqlen_q = torch.cumsum(seqlen, dim=0).to(device=self.device)

        attn_metadata = self.attn_backend.make_metadata(
            is_prompt=True,
            slot_mapping=slot_mapping,
            multi_modal_placeholder_index_maps=placeholder_index_maps,
            enable_kv_scales_calculation=False,
            seq_lens=seq_lens,
            seqlen_q=seqlen_q,
            max_seqlen=max_seqlen,
            seq_lens_tensor=torch.tensor([]),
            max_decode_seq_len=0,
            num_prefills=len(seq_lens),
            num_prefill_tokens=num_prompt_tokens,
            num_decode_tokens=0,
            block_tables=torch.tensor([], device=self.device, dtype=torch.int),
        )

        multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list)

        return (input_tokens, input_positions, attn_metadata, seq_lens,
                multi_modal_kwargs)
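
    # Example (illustrative values): for a prefill batch with
    # seq_lens = [3, 5], the cumulative sum above yields
    # seqlen_q = [0, 3, 8], i.e. each sequence's starting query offset plus
    # the total token count, which the XPU attention backend can consume as
    # cumulative sequence-length offsets for variable-length attention.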

    def _prepare_decode(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata]:
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[int] = []
        input_positions: List[int] = []
        slot_mapping: List[int] = []
        seq_lens: List[int] = []
        block_tables: List[List[int]] = []

        for seq_group_metadata in seq_group_metadata_list:
            assert not seq_group_metadata.is_prompt
            assert seq_group_metadata.token_chunk_size == 1

            seq_ids = list(seq_group_metadata.seq_data.keys())

            for seq_id in seq_ids:
                seq_data = seq_group_metadata.seq_data[seq_id]
                generation_token = seq_data.get_last_token_id()
                input_tokens.append(generation_token)

                seq_len = seq_data.get_len()
                position = seq_len - 1
                input_positions.append(position)

                seq_len = seq_len if self.sliding_window is None else min(
                    seq_len, self.sliding_window)
                seq_lens.append(seq_len)

                block_table = seq_group_metadata.block_tables[seq_id]
                block_number = block_table[position // self.block_size]
                block_offset = position % self.block_size
                slot = block_number * self.block_size + block_offset
                slot_mapping.append(slot)

                if self.sliding_window is not None:
                    sliding_window_blocks = (self.sliding_window //
                                             self.block_size)
                    block_table = block_table[-sliding_window_blocks:]
                block_tables.append(block_table)

        max_decode_seq_len = max(seq_lens)

        input_tokens = torch.tensor(input_tokens,
                                    dtype=torch.long,
                                    device=self.device)
        input_positions = torch.tensor(input_positions,
                                       dtype=torch.long,
                                       device=self.device)
        slot_mapping = torch.tensor(slot_mapping,
                                    dtype=torch.long,
                                    device=self.device)
        seq_lens_tensor = torch.tensor(seq_lens,
                                       dtype=torch.int,
                                       device=self.device)

        block_tables = make_tensor_with_pad(
            block_tables,
            pad=0,
            dtype=torch.int,
            device=self.device,
        )

        attn_metadata = self.attn_backend.make_metadata(
            is_prompt=False,
            slot_mapping=slot_mapping,
            multi_modal_placeholder_index_maps=None,
            enable_kv_scales_calculation=False,
            seq_lens=seq_lens,
            seqlen_q=torch.tensor([]),
            max_seqlen=0,
            seq_lens_tensor=seq_lens_tensor,
            max_decode_seq_len=max_decode_seq_len,
            num_prefill_tokens=0,
            num_decode_tokens=len(input_tokens),
            num_prefills=0,
            block_tables=block_tables,
        )
        return (
            input_tokens,
            input_positions,
            attn_metadata,
        )
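
    # Example (illustrative values): with block_size = 16, a sequence of
    # length 22 decodes its next token at position 21, which falls in
    # block_table[21 // 16] == block_table[1] at offset 21 % 16 == 5, so the
    # slot recorded above is block_table[1] * 16 + 5.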


class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
    _model_input_cls: Type[ModelInputForXPUWithSamplingMetadata] = (
        ModelInputForXPUWithSamplingMetadata)
    _builder_cls: Type[ModelInputForXPUBuilder] = ModelInputForXPUBuilder

    def __init__(
        self,
        vllm_config: VllmConfig,
        kv_cache_dtype: Optional[str] = "auto",
        is_driver_worker: bool = False,
        return_hidden_states: bool = False,
        input_registry: InputRegistry = INPUT_REGISTRY,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
    ):

        ModelRunnerBase.__init__(self, vllm_config=vllm_config)
        model_config = self.model_config
        cache_config = self.cache_config
        self.is_driver_worker = is_driver_worker
        self.return_hidden_states = return_hidden_states

        self.device = self.device_config.device

        self.kv_cache_dtype = kv_cache_dtype
        self.sliding_window = model_config.get_sliding_window()
        self.block_size = cache_config.block_size

        self.attn_backend = get_attn_backend(
            self.model_config.get_head_size(),
            self.model_config.dtype,
            self.kv_cache_dtype,
            self.block_size,
            self.model_config.is_attention_free,
        )

        # Multi-modal data support
        self.input_registry = input_registry
        self.mm_registry = mm_registry

        # Lazy initialization.
        self.model: nn.Module  # Set after load_model
        self.sampler = get_sampler()

        self.sampling_metadata_cache: SamplingMetadataCache = \
            SamplingMetadataCache() \
            if self.parallel_config.pipeline_parallel_size == 1 else None

        self.builder = self._builder_cls(weakref.proxy(self))

    def load_model(self) -> None:
        with DeviceMemoryProfiler() as m:
            self.model = get_model(vllm_config=self.vllm_config)

        self.model_memory_usage = m.consumed_memory
        logger.info("Loading model weights took %.4f GiB",
                    self.model_memory_usage / GiB_bytes)

    def get_model(self) -> nn.Module:
        return self.model

    @property
    def vocab_size(self) -> int:
        return self.model_config.get_vocab_size()

    @torch.inference_mode()
    def profile_run(self) -> None:
        # Enable top-k sampling to reflect the accurate memory usage.
        sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
        max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
        max_num_seqs = self.scheduler_config.max_num_seqs

        # Profile memory usage with max_num_seqs sequences and the total
        # number of tokens equal to max_num_batched_tokens.
        seqs: List[SequenceGroupMetadata] = []
        # Additional GPU memory may be needed for multi-modal encoding, which
        # needs to be accounted for when calculating the GPU blocks for
        # the vLLM block manager.
        # To exercise the worst scenario for GPU memory consumption,
        # the number of seqs (batch_size) is chosen to maximize the number
        # of images processed.
        max_mm_tokens = self.mm_registry.get_max_multimodal_tokens(
            self.model_config)
        if max_mm_tokens > 0:
            max_num_seqs_orig = max_num_seqs
            max_num_seqs = min(max_num_seqs,
                               max_num_batched_tokens // max_mm_tokens)
            if max_num_seqs < 1:
                expr = (f"min({max_num_seqs_orig}, "
                        f"{max_num_batched_tokens} // {max_mm_tokens})")
                logger.warning(
                    "Computed max_num_seqs (%s) to be less than 1. "
                    "Setting it to the minimum value of 1.", expr)
                max_num_seqs = 1

        batch_size = 0
        for group_id in range(max_num_seqs):
            seq_len = (max_num_batched_tokens // max_num_seqs +
                       (group_id < max_num_batched_tokens % max_num_seqs))
            batch_size += seq_len

            dummy_data = self.input_registry \
                .dummy_data_for_profiling(self.model_config,
                                          seq_len,
                                          self.mm_registry)

            seq = SequenceGroupMetadata(
                request_id=str(group_id),
                is_prompt=True,
                seq_data={group_id: dummy_data.seq_data},
                sampling_params=sampling_params,
                block_tables=None,
                lora_request=None,
                multi_modal_data=dummy_data.multi_modal_data,
                multi_modal_placeholders=dummy_data.multi_modal_placeholders)
            seqs.append(seq)

        finished_requests_ids = [seq.request_id for seq in seqs]
        model_input = self.prepare_model_input(
            seqs, finished_requests_ids=finished_requests_ids)
        intermediate_tensors = None
        if not get_pp_group().is_first_rank:
            intermediate_tensors = self.model.make_empty_intermediate_tensors(
                batch_size=batch_size,
                dtype=self.model_config.dtype,
                device=self.device)
        self.execute_model(model_input, None, intermediate_tensors)
        torch.xpu.synchronize()
        return
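
    # Example (illustrative values): with max_num_batched_tokens = 2048 and
    # max_num_seqs = 256, the loop above builds 256 dummy prompts of
    # 2048 // 256 == 8 tokens each (a non-zero remainder would be spread one
    # extra token at a time over the first groups), so the profiled batch
    # always totals max_num_batched_tokens.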

    def make_model_input_from_broadcasted_tensor_dict(
            self,
            tensor_dict: Dict[str,
                              Any]) -> ModelInputForXPUWithSamplingMetadata:
        return (
            ModelInputForXPUWithSamplingMetadata.from_broadcasted_tensor_dict(
                tensor_dict,
                attn_backend=self.attn_backend,
            ))

    def _prepare_model_input_tensors(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        finished_requests_ids: Optional[List[str]] = None
    ) -> ModelInputForXPUWithSamplingMetadata:
        """Helper method to prepare the model input based on a given sequence
        group. Prepares metadata needed for the base model forward pass but not
        metadata for possible additional steps, e.g., sampling.

        """
        builder = self.builder
        builder.prepare(finished_requests_ids)
        for seq_group_metadata in seq_group_metadata_list:
            builder.add_seq_group(seq_group_metadata)

        return builder.build()  # type: ignore

    def prepare_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        virtual_engine: int = 0,
        finished_requests_ids: Optional[List[str]] = None
    ) -> ModelInputForXPUWithSamplingMetadata:
        """Prepare the model input based on a given sequence group, including
        metadata for the sampling step.

        """
        model_input = self._prepare_model_input_tensors(
            seq_group_metadata_list, finished_requests_ids)
        # Sampling metadata is only required for the final pp group
        generators = self.get_generators(finished_requests_ids)
        sampling_metadata = SamplingMetadata.prepare(
            seq_group_metadata_list,
            model_input.seq_lens,
            model_input.query_lens,
            self.device,
            pin_memory=False,
            generators=generators,
            cache=self.sampling_metadata_cache)

        return dataclasses.replace(model_input,
                                   sampling_metadata=sampling_metadata,
                                   virtual_engine=virtual_engine)

    @torch.inference_mode()
    def execute_model(
        self,
        model_input: ModelInputForXPUWithSamplingMetadata,
        kv_caches: List[torch.Tensor],
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
    ) -> Optional[List[SamplerOutput]]:
        if num_steps > 1:
            raise ValueError(
                "XPUModelRunner does not support multi-step execution.")

        model_executable = self.model
        if (self.observability_config is not None
                and self.observability_config.collect_model_forward_time):
            model_forward_start_time = time.time()
        with set_forward_context(model_input.attn_metadata, self.vllm_config,
                                 model_input.virtual_engine):
            hidden_or_intermediate_states = model_executable(
                input_ids=model_input.input_tokens,
                positions=model_input.input_positions,
                intermediate_tensors=intermediate_tensors,
                **MultiModalKwargs.as_kwargs(
                    model_input.multi_modal_kwargs or {},
                    device=self.device,
                ),
            )
        # Compute the logits in the last pipeline stage.
        if not get_pp_group().is_last_rank:
            return hidden_or_intermediate_states

        if (self.observability_config is not None
                and self.observability_config.collect_model_forward_time):
            model_forward_end_time = time.time()

        # Compute the logits.
        logits = self.model.compute_logits(hidden_or_intermediate_states,
                                           model_input.sampling_metadata)

        # Only perform sampling in the driver worker.
        if not self.is_driver_worker:
            return []

        if model_input.async_callback is not None:
            model_input.async_callback()

        # Sample the next token.
        output: SamplerOutput = self.sampler(
            logits=logits,
            sampling_metadata=model_input.sampling_metadata,
        )
        if (self.observability_config is not None
                and self.observability_config.collect_model_forward_time
                and output is not None):
            model_forward_time = (model_forward_end_time -
                                  model_forward_start_time)
            # If there are multiple workers, we are still tracking the latency
            # from the start time of the driver worker to the end time of the
            # driver worker. The model forward time will then end up covering
            # the communication time as well.
            output.model_forward_time = model_forward_time

        return [output]
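
    # Note: in the usual worker flow, the driver rank calls
    # prepare_model_input() to build the sampling-aware model input (its
    # tensor dict can be broadcast to the other ranks via
    # as_broadcastable_tensor_dict), then every rank calls execute_model();
    # only the driver samples, while non-driver ranks return an empty list
    # above.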
@ -0,0 +1,186 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""An XPU worker class."""
import gc
import os
from typing import List, Optional, Tuple

import intel_extension_for_pytorch  # noqa: F401
import oneccl_bindings_for_pytorch  # noqa: F401
import torch
import torch.distributed

from vllm.config import VllmConfig
from vllm.distributed import (ensure_model_parallel_initialized,
                              init_distributed_environment)
from vllm.distributed.parallel_state import get_pp_group
from vllm.logger import init_logger
from vllm.model_executor import set_random_seed
from vllm.platforms import current_platform
from vllm.worker.cache_engine import CacheEngine
from vllm.worker.worker import Worker
from vllm.worker.worker_base import LoRANotSupportedWorkerBase, WorkerBase
from vllm.worker.xpu_model_runner import XPUModelRunner

logger = init_logger(__name__)


class XPUWorker(LoRANotSupportedWorkerBase, Worker):
    """A worker class that executes (a partition of) the model on an XPU.

    Each worker is associated with a single XPU device. The worker is
    responsible for maintaining the KV cache and executing the model on the
    XPU. In case of distributed inference, each worker is assigned a partition
    of the model.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        is_driver_worker: bool = False,
    ) -> None:
        WorkerBase.__init__(self, vllm_config=vllm_config)
        device_config = self.device_config
        parallel_config = self.parallel_config
        assert device_config.device_type == "xpu"
        assert current_platform.is_xpu()

        self.parallel_config.rank = rank

        self.local_rank = local_rank
        self.rank = rank
        self.distributed_init_method = distributed_init_method
        self.is_driver_worker = is_driver_worker
        if parallel_config and is_driver_worker:
            assert rank % parallel_config.tensor_parallel_size == 0, \
                "Driver worker should be rank 0 of tensor parallel group."

        self.model_runner = XPUModelRunner(  # type: ignore
            vllm_config=vllm_config,
            kv_cache_dtype=self.cache_config.cache_dtype,
            is_driver_worker=is_driver_worker,
        )
        # Uninitialized cache engine. Will be initialized by
        # initialize_cache.
        self.cache_engine: List[CacheEngine]
        self.gpu_cache: Optional[List[List[torch.Tensor]]]

    def init_device(self) -> None:
        if self.device_config.device.type == "xpu" and current_platform.is_xpu(
        ):
            self.device = torch.device(f"xpu:{self.local_rank}")
            torch.xpu.set_device(self.device)
            torch.xpu.empty_cache()
            self.init_gpu_memory = torch.xpu.get_device_properties(
                self.local_rank).total_memory
        else:
            raise RuntimeError(
                f"Unsupported device type: {self.device_config.device}")
        # Initialize the distributed environment.
        self.init_worker_distributed_environment()
        # Initialize the model.
        set_random_seed(self.model_config.seed)

    # Keep this method for the `empty_cache` and `synchronize` APIs.
    @torch.inference_mode()
    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Profiles the peak memory usage of the model to determine how many
        KV blocks may be allocated without OOMs.

        The engine will first conduct a profiling of the existing memory usage.
        Then, it calculates the maximum possible number of GPU and CPU blocks
        that can be allocated with the remaining free memory.

        Tip:
            You may limit the usage of GPU memory
            by adjusting the `gpu_memory_utilization` parameter.
        """
        # Profile the memory usage of the model and get the maximum number of
        # cache blocks that can be allocated with the remaining free memory.
        torch.xpu.empty_cache()

        # Execute a forward pass with dummy inputs to profile the memory usage
        # of the model.
        self.model_runner.profile_run()

        # Calculate the number of blocks that can be allocated with the
        # profiled peak memory.
        torch.xpu.synchronize()
        used_memory = torch.xpu.memory_allocated()
        total_gpu_memory = torch.xpu.get_device_properties(
            self.local_rank).total_memory
        free_gpu_memory = total_gpu_memory - used_memory

        # NOTE(woosuk): Here we assume that the other processes using the same
        # GPU did not change their memory usage during the profiling.
        peak_memory = self.init_gpu_memory - free_gpu_memory
        assert peak_memory > 0, (
            "Error in memory profiling. "
            f"Initial free memory {self.init_gpu_memory}, current free memory"
            f" {free_gpu_memory}. This happens when the GPU memory was "
            "not properly cleaned up before initializing the vLLM instance.")

        cache_block_size = self.get_cache_block_size_bytes()
        num_gpu_blocks = int(
            (total_gpu_memory * self.cache_config.gpu_memory_utilization -
             peak_memory) // cache_block_size)
        num_cpu_blocks = int(self.cache_config.swap_space_bytes //
                             cache_block_size)
        num_gpu_blocks = max(num_gpu_blocks, 0)
        num_cpu_blocks = max(num_cpu_blocks, 0)
        gc.collect()
        torch.xpu.empty_cache()
        return num_gpu_blocks, num_cpu_blocks
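
    # Example (illustrative values): with 16 GiB of device memory,
    # gpu_memory_utilization = 0.9 and a profiled peak of 6 GiB, roughly
    # 16 * 0.9 - 6 = 8.4 GiB is left for the KV cache; at 2 MiB per cache
    # block that is about 4300 GPU blocks, while num_cpu_blocks depends only
    # on swap_space_bytes.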

    def _warm_up_model(self) -> None:
        # IPEX does not support graph capture yet.
        pass

    def init_worker_distributed_environment(self) -> None:
        """Initialize the distributed environment."""

        parallel_config = self.parallel_config
        rank = self.rank
        distributed_init_method = self.distributed_init_method

        if torch.distributed.is_initialized():
            torch_world_size = torch.distributed.get_world_size()
            if torch_world_size != parallel_config.world_size:
                raise RuntimeError(
                    "torch.distributed is already initialized but the torch "
                    "world size does not match parallel_config.world_size "
                    f"({torch_world_size} vs. {parallel_config.world_size}).")
        elif not distributed_init_method:
            raise ValueError(
                "distributed_init_method must be set if torch.distributed "
                "is not already initialized")
        else:
            # Use sockets as the default Level Zero IPC exchange backend. By
            # default oneccl uses `drmfd` as the mechanism, which needs extra
            # dependencies (libdrm and drm headers) on your system.
            ENV_CCL_ATL_TRANSPORT = os.getenv("CCL_ATL_TRANSPORT", "ofi")
            ENV_LOCAL_WORLD_SIZE = os.getenv("LOCAL_WORLD_SIZE",
                                             str(parallel_config.world_size))
            os.environ["CCL_ATL_TRANSPORT"] = ENV_CCL_ATL_TRANSPORT
            os.environ["LOCAL_WORLD_SIZE"] = ENV_LOCAL_WORLD_SIZE
            os.environ["LOCAL_RANK"] = str(self.local_rank)
            init_distributed_environment(
                world_size=parallel_config.world_size,
                rank=rank,
                distributed_init_method=distributed_init_method,
                local_rank=self.local_rank,
                backend="ccl")

        ensure_model_parallel_initialized(
            parallel_config.tensor_parallel_size,
            parallel_config.pipeline_parallel_size)
        # Global all_reduce needed for overall oneccl warm up.
        torch.distributed.all_reduce(torch.zeros(1).xpu())

        if parallel_config.pipeline_parallel_size > 1:
            # Add pp group init to avoid
            # p2p communication as the first call
            get_pp_group().all_reduce(torch.zeros(1).xpu())
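
        # Example (illustrative values): when torch.distributed is not yet
        # initialized, a single-node 2-way tensor-parallel launch ends up with
        # CCL_ATL_TRANSPORT=ofi, LOCAL_WORLD_SIZE=2 and LOCAL_RANK set to 0 or
        # 1 (unless overridden in the environment) before
        # init_distributed_environment() is called with backend="ccl".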