mirror of https://github.com/vllm-project/vllm.git
Compare commits
1 Commits
Author | SHA1 | Date |
---|---|---|
|
a5dd03c1eb |
|
@ -66,10 +66,10 @@ function cpu_tests() {
|
||||||
tests/quantization/test_compressed_tensors.py::test_compressed_tensors_w8a8_dynamic_per_token"
|
tests/quantization/test_compressed_tensors.py::test_compressed_tensors_w8a8_dynamic_per_token"
|
||||||
|
|
||||||
# Run AWQ test
|
# Run AWQ test
|
||||||
# docker exec cpu-test-"$NUMA_NODE" bash -c "
|
docker exec cpu-test-"$NUMA_NODE" bash -c "
|
||||||
# set -e
|
set -e
|
||||||
# VLLM_USE_V1=0 pytest -s -v \
|
VLLM_USE_V1=0 pytest -s -v \
|
||||||
# tests/quantization/test_ipex_quant.py"
|
tests/quantization/test_ipex_quant.py"
|
||||||
|
|
||||||
# Run chunked-prefill and prefix-cache test
|
# Run chunked-prefill and prefix-cache test
|
||||||
docker exec cpu-test-"$NUMA_NODE" bash -c "
|
docker exec cpu-test-"$NUMA_NODE" bash -c "
|
||||||
|
|
|
@ -26,5 +26,7 @@ docker run \
|
||||||
--name "${container_name}" \
|
--name "${container_name}" \
|
||||||
"${image_name}" \
|
"${image_name}" \
|
||||||
sh -c '
|
sh -c '
|
||||||
|
VLLM_USE_V1=0 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m
|
||||||
|
VLLM_USE_V1=0 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m -tp 2
|
||||||
VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager
|
VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager
|
||||||
'
|
'
|
||||||
|
|
|
@ -8,7 +8,7 @@ image:
|
||||||
# -- Image tag
|
# -- Image tag
|
||||||
tag: "latest"
|
tag: "latest"
|
||||||
# -- Container launch command
|
# -- Container launch command
|
||||||
command: ["vllm", "serve", "/data/", "--served-model-name", "opt-125m", "--enforce-eager", "--dtype", "bfloat16", "--block-size", "16", "--host", "0.0.0.0", "--port", "8000"]
|
command: ["vllm", "serve", "/data/", "--served-model-name", "opt-125m", "--dtype", "float32", "--block-size", "16", "--host", "0.0.0.0", "--port", "8000"]
|
||||||
|
|
||||||
# -- Container port
|
# -- Container port
|
||||||
containerPort: 8000
|
containerPort: 8000
|
||||||
|
|
|
@ -36,8 +36,7 @@ DEVICE_REGULAR_ATTN_BACKENDS = {
|
||||||
DEVICE_MLA_BLOCK_SIZES = {
|
DEVICE_MLA_BLOCK_SIZES = {
|
||||||
"cuda": [16, 64], # CUDA supports both standard and extended block sizes
|
"cuda": [16, 64], # CUDA supports both standard and extended block sizes
|
||||||
"hip": [16, 1], # HIP requires special handling for block_size=1
|
"hip": [16, 1], # HIP requires special handling for block_size=1
|
||||||
# "cpu": [16] # CPU uses fixed block size from test cases
|
"cpu": [16] # CPU uses fixed block size from test cases
|
||||||
"cpu": [] # FIXME(woosuk): Temporarily disable CPU tests
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -82,14 +81,14 @@ def test_env(
|
||||||
m.setenv("VLLM_MLA_DISABLE", "1" if use_mla else "0")
|
m.setenv("VLLM_MLA_DISABLE", "1" if use_mla else "0")
|
||||||
|
|
||||||
if device == "cpu":
|
if device == "cpu":
|
||||||
if not use_v1:
|
|
||||||
pytest.skip("CPU backend only supports V1")
|
|
||||||
|
|
||||||
with patch("vllm.attention.selector.current_platform",
|
with patch("vllm.attention.selector.current_platform",
|
||||||
CpuPlatform()):
|
CpuPlatform()):
|
||||||
backend = get_attn_backend(16, torch.float16, torch.float16,
|
backend = get_attn_backend(16, torch.float16, torch.float16,
|
||||||
block_size, False)
|
block_size, False)
|
||||||
assert backend.get_name() == "TORCH_SDPA_VLLM_V1"
|
if use_v1:
|
||||||
|
assert backend.get_name() == "TORCH_SDPA_VLLM_V1"
|
||||||
|
else:
|
||||||
|
assert backend.get_name() == "TORCH_SDPA"
|
||||||
|
|
||||||
elif device == "hip":
|
elif device == "hip":
|
||||||
with patch("vllm.attention.selector.current_platform",
|
with patch("vllm.attention.selector.current_platform",
|
||||||
|
@ -205,14 +204,12 @@ def test_fp32_fallback(
|
||||||
m.setenv("VLLM_USE_V1", "1" if use_v1 else "0")
|
m.setenv("VLLM_USE_V1", "1" if use_v1 else "0")
|
||||||
|
|
||||||
if device == "cpu":
|
if device == "cpu":
|
||||||
if not use_v1:
|
|
||||||
pytest.skip("CPU backend only supports V1")
|
|
||||||
|
|
||||||
with patch("vllm.attention.selector.current_platform",
|
with patch("vllm.attention.selector.current_platform",
|
||||||
CpuPlatform()):
|
CpuPlatform()):
|
||||||
backend = get_attn_backend(16, torch.float32, torch.float32,
|
backend = get_attn_backend(16, torch.float32, torch.float32,
|
||||||
16, False)
|
16, False)
|
||||||
assert backend.get_name() == "TORCH_SDPA_VLLM_V1"
|
assert (backend.get_name() == "TORCH_SDPA_VLLM_V1"
|
||||||
|
if use_v1 else "TORCH_SDPA")
|
||||||
|
|
||||||
elif device == "cuda":
|
elif device == "cuda":
|
||||||
with patch("vllm.attention.selector.current_platform",
|
with patch("vllm.attention.selector.current_platform",
|
||||||
|
|
|
@ -0,0 +1,307 @@
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any, Dict, List, Optional, Tuple, Type
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
import vllm._custom_ops as ops
|
||||||
|
from vllm._ipex_ops import ipex_ops
|
||||||
|
from vllm.attention.backends.abstract import (AttentionBackend,
|
||||||
|
AttentionMetadataBuilder,
|
||||||
|
AttentionType,
|
||||||
|
is_quantized_kv_cache)
|
||||||
|
from vllm.attention.backends.mla.common import MLACommonImpl, MLACommonState
|
||||||
|
from vllm.attention.backends.torch_sdpa import TorchSDPAMetadata
|
||||||
|
from vllm.utils import make_tensor_with_pad
|
||||||
|
from vllm.worker.cpu_model_runner import ModelInputForCPUBuilder
|
||||||
|
|
||||||
|
|
||||||
|
class CPUMLABackend(AttentionBackend):
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_name() -> str:
|
||||||
|
return "CPU_MLA"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_metadata_cls() -> Type["CPUMLAMetadata"]:
|
||||||
|
return CPUMLAMetadata
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_builder_cls() -> Type["CPUMLAMetadataBuilder"]:
|
||||||
|
return CPUMLAMetadataBuilder
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_state_cls() -> Type["MLACommonState"]:
|
||||||
|
return MLACommonState
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_impl_cls() -> Type["CPUMLAImpl"]:
|
||||||
|
return CPUMLAImpl
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_kv_cache_shape(
|
||||||
|
num_blocks: int,
|
||||||
|
block_size: int,
|
||||||
|
num_kv_heads: int, # assumed to be 1 for MLA
|
||||||
|
head_size: int,
|
||||||
|
) -> Tuple[int, ...]:
|
||||||
|
return (num_blocks, block_size, head_size)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def swap_blocks(
|
||||||
|
src_kv_cache: torch.Tensor,
|
||||||
|
dst_kv_cache: torch.Tensor,
|
||||||
|
src_to_dst: torch.Tensor,
|
||||||
|
) -> None:
|
||||||
|
ops.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def copy_blocks(
|
||||||
|
kv_caches: List[torch.Tensor],
|
||||||
|
src_to_dists: torch.Tensor,
|
||||||
|
) -> None:
|
||||||
|
ops.copy_blocks_mla(kv_caches, src_to_dists)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_supported_head_sizes() -> List[int]:
|
||||||
|
return [576]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CPUMLAMetadata(TorchSDPAMetadata):
|
||||||
|
# New for MLA
|
||||||
|
# Input positions for rotrary embeddings since for MLA the rotary
|
||||||
|
# position embeddings are applied inside the attention backend
|
||||||
|
input_positions: torch.Tensor = None
|
||||||
|
|
||||||
|
# required by MLACommonImpl
|
||||||
|
is_profile_run: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
class CPUMLAMetadataBuilder(AttentionMetadataBuilder[CPUMLAMetadata]):
|
||||||
|
|
||||||
|
def __init__(self, input_builder: ModelInputForCPUBuilder) -> None:
|
||||||
|
self.chunked_prefill = input_builder.chunked_prefill
|
||||||
|
self.input_builder = input_builder
|
||||||
|
assert not self.chunked_prefill, \
|
||||||
|
"chunked prefill is currently not supported"
|
||||||
|
|
||||||
|
def prepare(self):
|
||||||
|
self.input_data = self.input_builder.input_data
|
||||||
|
|
||||||
|
def build(self, seq_lens, query_lens, cuda_graph_pad_size, batch_size):
|
||||||
|
input_data = self.input_data
|
||||||
|
prefill_seq_lens = seq_lens[0:input_data.num_prefills]
|
||||||
|
prefill_query_lens = query_lens[0:input_data.num_prefills]
|
||||||
|
slot_mapping = torch.tensor(input_data.slot_mapping,
|
||||||
|
dtype=torch.long,
|
||||||
|
device="cpu")
|
||||||
|
|
||||||
|
# metadata for prefill
|
||||||
|
if input_data.num_prefills > 0:
|
||||||
|
query_lens_tensor = torch.tensor(prefill_query_lens,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device="cpu")
|
||||||
|
kv_lens_tensor = torch.tensor(prefill_seq_lens,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device="cpu")
|
||||||
|
query_start_loc = torch.zeros(input_data.num_prefills + 1,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device="cpu")
|
||||||
|
kv_start_loc = torch.zeros(input_data.num_prefills + 1,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device="cpu")
|
||||||
|
torch.cumsum(query_lens_tensor,
|
||||||
|
dim=0,
|
||||||
|
dtype=torch.int32,
|
||||||
|
out=query_start_loc[1:])
|
||||||
|
torch.cumsum(kv_lens_tensor,
|
||||||
|
dim=0,
|
||||||
|
dtype=torch.int32,
|
||||||
|
out=kv_start_loc[1:])
|
||||||
|
max_query_len = max(prefill_query_lens)
|
||||||
|
max_kv_len = max(prefill_seq_lens)
|
||||||
|
|
||||||
|
# for chunked-prefill
|
||||||
|
if self.chunked_prefill:
|
||||||
|
prefill_block_tables = make_tensor_with_pad(
|
||||||
|
self.input_data.prefill_block_tables,
|
||||||
|
pad=0,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device="cpu",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
prefill_block_tables = None
|
||||||
|
|
||||||
|
else:
|
||||||
|
query_start_loc = None
|
||||||
|
kv_start_loc = None
|
||||||
|
max_query_len = None
|
||||||
|
max_kv_len = None
|
||||||
|
prefill_block_tables = None
|
||||||
|
|
||||||
|
# metadata for decode
|
||||||
|
if input_data.num_decode_tokens != 0:
|
||||||
|
seq_lens_tensor = torch.tensor(
|
||||||
|
input_data.seq_lens[input_data.num_prefills:],
|
||||||
|
dtype=torch.int32,
|
||||||
|
device="cpu",
|
||||||
|
)
|
||||||
|
block_tables = make_tensor_with_pad(
|
||||||
|
self.input_data.decode_block_tables,
|
||||||
|
pad=0,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device="cpu",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
block_tables = torch.tensor([])
|
||||||
|
seq_lens_tensor = torch.tensor(
|
||||||
|
input_data.seq_lens[:input_data.num_prefills],
|
||||||
|
dtype=torch.int32,
|
||||||
|
device="cpu",
|
||||||
|
)
|
||||||
|
|
||||||
|
# For multi-modal models
|
||||||
|
placeholder_index_maps = None
|
||||||
|
if len(input_data.multi_modal_inputs_list) != 0:
|
||||||
|
placeholder_index_maps = {
|
||||||
|
modality: placeholder_map.index_map()
|
||||||
|
for modality, placeholder_map in
|
||||||
|
input_data.multi_modal_placeholder_maps.items()
|
||||||
|
}
|
||||||
|
|
||||||
|
return CPUMLAMetadata(
|
||||||
|
chunked_prefill=self.chunked_prefill,
|
||||||
|
seq_lens=prefill_seq_lens,
|
||||||
|
seq_lens_tensor=seq_lens_tensor,
|
||||||
|
max_query_len=max_query_len,
|
||||||
|
max_kv_len=max_kv_len,
|
||||||
|
prefill_query_start_loc=query_start_loc,
|
||||||
|
kv_start_loc=kv_start_loc,
|
||||||
|
max_decode_seq_len=input_data.max_decode_seq_len,
|
||||||
|
num_prefills=input_data.num_prefills,
|
||||||
|
num_prefill_tokens=input_data.num_prefill_tokens,
|
||||||
|
num_decode_tokens=input_data.num_decode_tokens,
|
||||||
|
block_tables=block_tables,
|
||||||
|
prefill_block_tables=prefill_block_tables,
|
||||||
|
slot_mapping=slot_mapping,
|
||||||
|
multi_modal_placeholder_index_maps=placeholder_index_maps,
|
||||||
|
enable_kv_scales_calculation=False,
|
||||||
|
input_positions=torch.tensor([self.input_data.input_positions]))
|
||||||
|
|
||||||
|
|
||||||
|
class CPUMLAImpl(MLACommonImpl[CPUMLAMetadata]):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
num_heads: int,
|
||||||
|
head_size: int,
|
||||||
|
scale: float,
|
||||||
|
num_kv_heads: int,
|
||||||
|
alibi_slopes: Optional[List[float]],
|
||||||
|
sliding_window: Optional[int],
|
||||||
|
kv_cache_dtype: str,
|
||||||
|
blocksparse_params: Optional[Dict[str, Any]],
|
||||||
|
logits_soft_cap: Optional[float],
|
||||||
|
attn_type: str,
|
||||||
|
kv_sharing_target_layer_name: Optional[str],
|
||||||
|
# MLA Specific Arguments
|
||||||
|
**mla_args) -> None:
|
||||||
|
super().__init__(num_heads, head_size, scale, num_kv_heads,
|
||||||
|
alibi_slopes, sliding_window, kv_cache_dtype,
|
||||||
|
blocksparse_params, logits_soft_cap, attn_type,
|
||||||
|
kv_sharing_target_layer_name, **mla_args)
|
||||||
|
|
||||||
|
unsupported_features = [
|
||||||
|
alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
|
||||||
|
]
|
||||||
|
if any(unsupported_features):
|
||||||
|
raise NotImplementedError(
|
||||||
|
"CPUMLAImpl does not support one of the following: "
|
||||||
|
"alibi_slopes, sliding_window, blocksparse_params, "
|
||||||
|
"logits_soft_cap")
|
||||||
|
|
||||||
|
if attn_type != AttentionType.DECODER:
|
||||||
|
raise NotImplementedError("Encoder self-attention and "
|
||||||
|
"encoder/decoder cross-attention "
|
||||||
|
"are not implemented for "
|
||||||
|
"CPUMLAImpl")
|
||||||
|
|
||||||
|
# states is implemented.
|
||||||
|
if is_quantized_kv_cache(self.kv_cache_dtype):
|
||||||
|
raise NotImplementedError(
|
||||||
|
"CPUMLAImpl with FP8 KV cache not yet supported")
|
||||||
|
|
||||||
|
def _forward_prefill(
|
||||||
|
self,
|
||||||
|
q: torch.Tensor,
|
||||||
|
kv_c_normed: torch.Tensor,
|
||||||
|
k_pe: torch.Tensor,
|
||||||
|
kv_c_and_k_pe_cache: torch.Tensor,
|
||||||
|
attn_metadata: CPUMLAMetadata, # type: ignore[override]
|
||||||
|
) -> torch.Tensor:
|
||||||
|
|
||||||
|
prefill_metadata = attn_metadata.prefill_metadata
|
||||||
|
assert prefill_metadata is not None
|
||||||
|
|
||||||
|
kv_nope = self.kv_b_proj(kv_c_normed)[0].view(\
|
||||||
|
-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
|
||||||
|
k_nope, v = kv_nope\
|
||||||
|
.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
|
||||||
|
|
||||||
|
k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
|
||||||
|
|
||||||
|
# For MLA the v head dim is smaller than qk head dim so we pad out
|
||||||
|
# v with 0s to match the qk head dim
|
||||||
|
v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]],
|
||||||
|
value=0)
|
||||||
|
|
||||||
|
output = torch.empty_like(q)
|
||||||
|
ipex_ops.varlen_attention(
|
||||||
|
query=q,
|
||||||
|
key=k,
|
||||||
|
value=v_padded,
|
||||||
|
out=output,
|
||||||
|
seqlen_q=prefill_metadata.prefill_query_start_loc,
|
||||||
|
seqlen_k=prefill_metadata.prefill_query_start_loc,
|
||||||
|
max_seqlen_q=prefill_metadata.max_query_len,
|
||||||
|
max_seqlen_k=prefill_metadata.max_query_len,
|
||||||
|
pdropout=0.0,
|
||||||
|
softmax_scale=self.scale,
|
||||||
|
zero_tensors=False,
|
||||||
|
is_causal=True,
|
||||||
|
return_softmax=False,
|
||||||
|
gen_=None,
|
||||||
|
logits_soft_cap=0.0,
|
||||||
|
window_size_left=-1,
|
||||||
|
window_size_right=-1,
|
||||||
|
alibi_slopes=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
# remove padding
|
||||||
|
output = output.view(-1, self.num_heads,
|
||||||
|
q.shape[-1])[..., :v.shape[-1]]
|
||||||
|
return output.reshape(-1, self.num_heads * v.shape[-1])
|
||||||
|
|
||||||
|
def _forward_decode(
|
||||||
|
self,
|
||||||
|
q_nope: torch.Tensor,
|
||||||
|
q_pe: torch.Tensor,
|
||||||
|
kv_c_and_k_pe_cache: torch.Tensor,
|
||||||
|
attn_metadata: CPUMLAMetadata, # type: ignore[override]
|
||||||
|
) -> torch.Tensor:
|
||||||
|
assert kv_c_and_k_pe_cache.numel() > 0
|
||||||
|
|
||||||
|
decode_meta = attn_metadata.decode_metadata
|
||||||
|
assert decode_meta is not None
|
||||||
|
|
||||||
|
q = torch.cat([q_nope, q_pe], dim=-1)
|
||||||
|
o = q.new_empty(q.shape[0], self.num_heads, self.kv_lora_rank)
|
||||||
|
|
||||||
|
# Run MQA
|
||||||
|
ops.mla_decode_kvcache_cpu(o, q, kv_c_and_k_pe_cache, self.scale,
|
||||||
|
decode_meta.block_tables,
|
||||||
|
decode_meta.seq_lens_tensor)
|
||||||
|
return self._v_up_proj(o)
|
|
@ -0,0 +1,403 @@
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
""" Attention layer with torch scaled_dot_product_attention
|
||||||
|
and PagedAttention."""
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any, Dict, List, Optional, Tuple, Type
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm._ipex_ops import ipex_ops
|
||||||
|
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||||
|
AttentionLayer,
|
||||||
|
AttentionMetadata, AttentionType,
|
||||||
|
is_quantized_kv_cache)
|
||||||
|
from vllm.attention.backends.utils import CommonAttentionState
|
||||||
|
from vllm.attention.ops.paged_attn import (PagedAttention,
|
||||||
|
PagedAttentionMetadata)
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
_PARTITION_SIZE = 512
|
||||||
|
|
||||||
|
|
||||||
|
class IpexAttnBackend(AttentionBackend):
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_name() -> str:
|
||||||
|
return "IPEX"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_impl_cls() -> Type["IpexAttnBackendImpl"]:
|
||||||
|
return IpexAttnBackendImpl
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_metadata_cls() -> Type["IpexAttnMetadata"]:
|
||||||
|
return IpexAttnMetadata
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_state_cls() -> Type["CommonAttentionState"]:
|
||||||
|
return CommonAttentionState
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_kv_cache_shape(
|
||||||
|
num_blocks: int,
|
||||||
|
block_size: int,
|
||||||
|
num_kv_heads: int,
|
||||||
|
head_size: int,
|
||||||
|
) -> Tuple[int, ...]:
|
||||||
|
return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
|
||||||
|
num_kv_heads, head_size)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def swap_blocks(
|
||||||
|
src_kv_cache: torch.Tensor,
|
||||||
|
dst_kv_cache: torch.Tensor,
|
||||||
|
src_to_dst: torch.Tensor,
|
||||||
|
) -> None:
|
||||||
|
from vllm._ipex_ops import ipex_ops as ops
|
||||||
|
ops.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def copy_blocks(
|
||||||
|
kv_caches: List[torch.Tensor],
|
||||||
|
src_to_dists: torch.Tensor,
|
||||||
|
) -> None:
|
||||||
|
from vllm._ipex_ops import ipex_ops as ops
|
||||||
|
key_caches = [kv_cache[0] for kv_cache in kv_caches]
|
||||||
|
value_caches = [kv_cache[1] for kv_cache in kv_caches]
|
||||||
|
ops.copy_blocks(key_caches, value_caches, src_to_dists)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class IpexAttnMetadata(AttentionMetadata, PagedAttentionMetadata):
|
||||||
|
"""Metadata for IpexAttnBackend.
|
||||||
|
"""
|
||||||
|
# Currently, input sequences can only contain all prompts
|
||||||
|
# or all decoding. True if all sequences are prompts.
|
||||||
|
is_prompt: bool
|
||||||
|
slot_mapping: torch.Tensor
|
||||||
|
seq_lens: Optional[List[int]]
|
||||||
|
seqlen_q: Optional[torch.Tensor]
|
||||||
|
max_seqlen: Optional[int]
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
# Set during the execution of the first attention op.
|
||||||
|
# It is a list because it is needed to set per prompt
|
||||||
|
# when alibi slopes is used. It is because of the limitation
|
||||||
|
# from xformer API.
|
||||||
|
# will not appear in the __repr__ and __init__
|
||||||
|
self.attn_bias: Optional[List[torch.Tensor]] = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def prefill_metadata(self) -> Optional["IpexAttnMetadata"]:
|
||||||
|
# Currently chunked prefill is not supported
|
||||||
|
if self.num_decode_tokens == 0:
|
||||||
|
assert self.num_prefills > 0
|
||||||
|
return self
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def decode_metadata(self) -> Optional["IpexAttnMetadata"]:
|
||||||
|
# Currently chunked prefill is not supported
|
||||||
|
if self.num_prefills > 0:
|
||||||
|
assert self.num_decode_tokens == 0
|
||||||
|
return None
|
||||||
|
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
num_heads: int,
|
||||||
|
head_size: int,
|
||||||
|
scale: float,
|
||||||
|
num_kv_heads: int,
|
||||||
|
alibi_slopes: Optional[List[float]],
|
||||||
|
sliding_window: Optional[int],
|
||||||
|
kv_cache_dtype: str,
|
||||||
|
blocksparse_params: Optional[Dict[str, Any]] = None,
|
||||||
|
logits_soft_cap: Optional[float] = None,
|
||||||
|
attn_type: str = AttentionType.DECODER,
|
||||||
|
kv_sharing_target_layer_name: Optional[str] = None,
|
||||||
|
use_irope: bool = False,
|
||||||
|
) -> None:
|
||||||
|
if kv_sharing_target_layer_name is not None:
|
||||||
|
raise NotImplementedError("KV sharing is not supported in V0.")
|
||||||
|
if use_irope:
|
||||||
|
logger.warning_once(
|
||||||
|
"Using irope in Ipex is not supported yet, it will fall"
|
||||||
|
" back to global attention for long context.")
|
||||||
|
if blocksparse_params is not None:
|
||||||
|
raise ValueError(
|
||||||
|
"IPEX backend does not support block-sparse attention.")
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.head_size = head_size
|
||||||
|
self.scale = float(scale)
|
||||||
|
self.num_kv_heads = num_kv_heads
|
||||||
|
if alibi_slopes is not None:
|
||||||
|
alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32)
|
||||||
|
self.alibi_slopes = alibi_slopes
|
||||||
|
self.sliding_window = sliding_window
|
||||||
|
self.kv_cache_dtype = kv_cache_dtype
|
||||||
|
|
||||||
|
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||||
|
self.need_mask = (self.sliding_window is not None)
|
||||||
|
if logits_soft_cap is None:
|
||||||
|
logits_soft_cap = -1
|
||||||
|
self.logits_soft_cap = logits_soft_cap
|
||||||
|
|
||||||
|
supported_head_sizes = PagedAttention.get_supported_head_sizes()
|
||||||
|
if head_size not in supported_head_sizes:
|
||||||
|
raise ValueError(
|
||||||
|
f"Head size {head_size} is not supported by PagedAttention. "
|
||||||
|
f"Supported head sizes are: {supported_head_sizes}.")
|
||||||
|
if is_quantized_kv_cache(kv_cache_dtype):
|
||||||
|
raise NotImplementedError(
|
||||||
|
"IPEX backend does not support FP8 KV cache. "
|
||||||
|
"Please use xFormers backend instead.")
|
||||||
|
if attn_type != AttentionType.DECODER:
|
||||||
|
raise NotImplementedError("Encoder self-attention and "
|
||||||
|
"encoder/decoder cross-attention "
|
||||||
|
"are not implemented for "
|
||||||
|
"IpexAttnBackendImpl")
|
||||||
|
|
||||||
|
def split_kv_cache(
|
||||||
|
self,
|
||||||
|
kv_cache: torch.Tensor,
|
||||||
|
num_kv_heads: int,
|
||||||
|
head_size: int,
|
||||||
|
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
x = 1
|
||||||
|
num_blocks = kv_cache.shape[1]
|
||||||
|
|
||||||
|
key_cache = kv_cache[0]
|
||||||
|
key_cache = key_cache.view(num_blocks, num_kv_heads, head_size // x,
|
||||||
|
-1, x)
|
||||||
|
value_cache = kv_cache[1]
|
||||||
|
value_cache = value_cache.view(num_blocks, num_kv_heads, head_size, -1)
|
||||||
|
return key_cache, value_cache
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
layer: AttentionLayer,
|
||||||
|
query: torch.Tensor,
|
||||||
|
key: torch.Tensor,
|
||||||
|
value: torch.Tensor,
|
||||||
|
kv_cache: torch.Tensor,
|
||||||
|
attn_metadata: IpexAttnMetadata, # type: ignore
|
||||||
|
output: Optional[torch.Tensor] = None,
|
||||||
|
output_scale: Optional[torch.Tensor] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Forward pass with IPEX varlen_attention and PagedAttention.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: shape = [num_tokens, num_heads * head_size]
|
||||||
|
key: shape = [num_tokens, num_kv_heads * head_size]
|
||||||
|
value: shape = [num_tokens, num_kv_heads * head_size]
|
||||||
|
kv_cache = [2, num_blocks, block_size * num_kv_heads * head_size]
|
||||||
|
NOTE: kv_cache will be an empty tensor with shape [0]
|
||||||
|
for profiling run.
|
||||||
|
attn_metadata: Metadata for attention.
|
||||||
|
Returns:
|
||||||
|
shape = [num_tokens, num_heads * head_size]
|
||||||
|
"""
|
||||||
|
if output_scale is not None:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"fused output quantization is not yet supported"
|
||||||
|
" for IpexAttentionImpl")
|
||||||
|
|
||||||
|
assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
|
||||||
|
num_tokens, hidden_size = query.shape
|
||||||
|
# Reshape the query, key, and value tensors.
|
||||||
|
query = query.view(-1, self.num_heads, self.head_size)
|
||||||
|
key = key.view(-1, self.num_kv_heads, self.head_size)
|
||||||
|
value = value.view(-1, self.num_kv_heads, self.head_size)
|
||||||
|
|
||||||
|
if kv_cache.numel() > 0:
|
||||||
|
key_cache, value_cache = self.split_kv_cache(
|
||||||
|
kv_cache, self.num_kv_heads, self.head_size)
|
||||||
|
ipex_ops.reshape_and_cache(
|
||||||
|
key,
|
||||||
|
value,
|
||||||
|
key_cache,
|
||||||
|
value_cache,
|
||||||
|
attn_metadata.slot_mapping.flatten(),
|
||||||
|
self.kv_cache_dtype,
|
||||||
|
layer._k_scale_float,
|
||||||
|
layer._v_scale_float,
|
||||||
|
)
|
||||||
|
|
||||||
|
if attn_metadata.is_prompt:
|
||||||
|
assert attn_metadata.seq_lens is not None
|
||||||
|
if (kv_cache.numel() == 0
|
||||||
|
or attn_metadata.block_tables.numel() == 0):
|
||||||
|
if self.num_kv_heads != self.num_heads:
|
||||||
|
key = key.repeat_interleave(self.num_queries_per_kv, dim=1)
|
||||||
|
value = value.repeat_interleave(self.num_queries_per_kv,
|
||||||
|
dim=1)
|
||||||
|
|
||||||
|
if attn_metadata.attn_bias is None:
|
||||||
|
if self.sliding_window is not None:
|
||||||
|
att_masks = _make_sliding_window_bias(
|
||||||
|
attn_metadata.seq_lens, self.sliding_window,
|
||||||
|
query.dtype) # type: ignore
|
||||||
|
else:
|
||||||
|
att_masks = _make_sliding_window_bias(
|
||||||
|
attn_metadata.seq_lens, None, dtype=query.dtype)
|
||||||
|
attn_metadata.attn_bias = att_masks
|
||||||
|
|
||||||
|
output = torch.empty(
|
||||||
|
(num_tokens, self.num_heads, self.head_size),
|
||||||
|
dtype=query.dtype,
|
||||||
|
device=query.device)
|
||||||
|
ipex_ops.varlen_attention(
|
||||||
|
query,
|
||||||
|
key,
|
||||||
|
value,
|
||||||
|
output,
|
||||||
|
attn_metadata.seqlen_q,
|
||||||
|
attn_metadata.seqlen_q,
|
||||||
|
self.alibi_slopes,
|
||||||
|
attn_metadata.max_seqlen,
|
||||||
|
attn_metadata.max_seqlen,
|
||||||
|
pdropout=0.0,
|
||||||
|
softmax_scale=self.scale,
|
||||||
|
zero_tensors=False,
|
||||||
|
is_causal=True,
|
||||||
|
return_softmax=False,
|
||||||
|
gen_=None,
|
||||||
|
window_size_left=-1,
|
||||||
|
window_size_right=-1,
|
||||||
|
logits_soft_cap=self.logits_soft_cap,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# prefix-enabled attention
|
||||||
|
raise RuntimeError(
|
||||||
|
"IPEX backend doesn't support prefix decoding.")
|
||||||
|
|
||||||
|
else:
|
||||||
|
# Decoding run.
|
||||||
|
max_seq_len = attn_metadata.max_decode_seq_len
|
||||||
|
output = torch.empty_like(query)
|
||||||
|
block_size = value_cache.shape[3]
|
||||||
|
num_seqs, num_heads, head_size = query.shape
|
||||||
|
max_num_partitions = ((max_seq_len + _PARTITION_SIZE - 1) //
|
||||||
|
_PARTITION_SIZE)
|
||||||
|
# NOTE(woosuk): We use a simple heuristic to decide whether to use
|
||||||
|
# PagedAttention V1 or V2. If the number of partitions is 1, we use
|
||||||
|
# V1 to avoid the overhead of reduction. Also, if the number of
|
||||||
|
# sequences or heads is large, we use V1 since there is enough work
|
||||||
|
# to parallelize.
|
||||||
|
# TODO(woosuk): Tune this heuristic.
|
||||||
|
# For context len > 8192, use V2 kernel to avoid shared memory
|
||||||
|
# shortage.
|
||||||
|
use_v1 = (max_seq_len <= 8192 and
|
||||||
|
(max_num_partitions == 1 or num_seqs * num_heads > 512))
|
||||||
|
if use_v1:
|
||||||
|
# Run PagedAttention V1.
|
||||||
|
ipex_ops.paged_attention_v1(
|
||||||
|
output,
|
||||||
|
query,
|
||||||
|
key_cache,
|
||||||
|
value_cache,
|
||||||
|
self.num_kv_heads,
|
||||||
|
self.scale,
|
||||||
|
attn_metadata.block_tables,
|
||||||
|
attn_metadata.seq_lens_tensor,
|
||||||
|
block_size,
|
||||||
|
max_seq_len,
|
||||||
|
self.alibi_slopes,
|
||||||
|
self.kv_cache_dtype,
|
||||||
|
layer._k_scale_float,
|
||||||
|
layer._v_scale_float,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Run PagedAttention V2.
|
||||||
|
assert _PARTITION_SIZE % block_size == 0
|
||||||
|
tmp_output = torch.empty(
|
||||||
|
size=(num_seqs, num_heads, max_num_partitions, head_size),
|
||||||
|
dtype=output.dtype,
|
||||||
|
device=output.device,
|
||||||
|
)
|
||||||
|
exp_sums = torch.empty(
|
||||||
|
size=(num_seqs, num_heads, max_num_partitions),
|
||||||
|
dtype=torch.float32,
|
||||||
|
device=output.device,
|
||||||
|
)
|
||||||
|
max_logits = torch.empty_like(exp_sums)
|
||||||
|
ipex_ops.paged_attention_v2(
|
||||||
|
output,
|
||||||
|
exp_sums,
|
||||||
|
max_logits,
|
||||||
|
tmp_output,
|
||||||
|
query,
|
||||||
|
key_cache,
|
||||||
|
value_cache,
|
||||||
|
self.num_kv_heads,
|
||||||
|
self.scale,
|
||||||
|
attn_metadata.block_tables,
|
||||||
|
attn_metadata.seq_lens_tensor,
|
||||||
|
block_size,
|
||||||
|
max_seq_len,
|
||||||
|
self.alibi_slopes,
|
||||||
|
self.kv_cache_dtype,
|
||||||
|
layer._k_scale_float,
|
||||||
|
layer._v_scale_float,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Reshape the output tensor.
|
||||||
|
return output.view(-1, self.num_heads * self.head_size)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_alibi_bias(
|
||||||
|
alibi_slopes: torch.Tensor,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
seq_lens: List[int],
|
||||||
|
) -> List[torch.Tensor]:
|
||||||
|
attn_biases = []
|
||||||
|
for seq_len in seq_lens:
|
||||||
|
bias = torch.arange(seq_len, dtype=dtype, device=alibi_slopes.device)
|
||||||
|
# NOTE(zhuohan): HF uses
|
||||||
|
# `bias = bias[None, :].repeat(seq_len, 1)`
|
||||||
|
# here. We find that both biases give the same results, but
|
||||||
|
# the bias below more accurately follows the original ALiBi
|
||||||
|
# paper.
|
||||||
|
bias = bias[None, :] - bias[:, None]
|
||||||
|
|
||||||
|
num_heads = alibi_slopes.shape[0]
|
||||||
|
bias = bias[None, :].repeat((num_heads, 1, 1))
|
||||||
|
bias.mul_(alibi_slopes[:, None, None])
|
||||||
|
inf_mask = torch.empty(
|
||||||
|
(1, seq_len, seq_len),
|
||||||
|
dtype=bias.dtype,
|
||||||
|
device=alibi_slopes.device).fill_(-torch.inf).triu_(diagonal=1)
|
||||||
|
attn_biases.append((bias + inf_mask).to(dtype))
|
||||||
|
|
||||||
|
return attn_biases
|
||||||
|
|
||||||
|
|
||||||
|
def _make_sliding_window_bias(
|
||||||
|
seq_lens: List[int],
|
||||||
|
window_size: Optional[int],
|
||||||
|
dtype: torch.dtype,
|
||||||
|
) -> List[torch.Tensor]:
|
||||||
|
attn_biases = []
|
||||||
|
for seq_len in seq_lens:
|
||||||
|
tensor = torch.full(
|
||||||
|
(1, seq_len, seq_len),
|
||||||
|
dtype=dtype,
|
||||||
|
fill_value=1,
|
||||||
|
)
|
||||||
|
shift = 0
|
||||||
|
mask = torch.tril(tensor, diagonal=shift).to(dtype) # type: ignore
|
||||||
|
if window_size is not None:
|
||||||
|
mask = torch.triu(mask, diagonal=shift - window_size + 1)
|
||||||
|
mask = torch.log(mask)
|
||||||
|
attn_biases.append(mask.to(dtype))
|
||||||
|
|
||||||
|
return attn_biases
|
|
@ -0,0 +1,356 @@
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any, Dict, List, Optional, Tuple, Type
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch_xla.experimental.custom_kernel # Required to register custom ops.
|
||||||
|
|
||||||
|
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||||
|
AttentionLayer,
|
||||||
|
AttentionMetadata, AttentionType,
|
||||||
|
is_quantized_kv_cache)
|
||||||
|
from vllm.attention.backends.utils import CommonAttentionState
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class PallasAttentionBackend(AttentionBackend):
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_name() -> str:
|
||||||
|
return "PALLAS"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_impl_cls() -> Type["PallasAttentionBackendImpl"]:
|
||||||
|
return PallasAttentionBackendImpl
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_metadata_cls() -> Type["PallasMetadata"]:
|
||||||
|
return PallasMetadata
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_state_cls() -> Type["CommonAttentionState"]:
|
||||||
|
return CommonAttentionState
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_kv_cache_shape(
|
||||||
|
num_blocks: int,
|
||||||
|
block_size: int,
|
||||||
|
num_kv_heads: int,
|
||||||
|
head_size: int,
|
||||||
|
) -> Tuple[int, ...]:
|
||||||
|
return (num_kv_heads, num_blocks, block_size, head_size)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def swap_blocks(
|
||||||
|
src_kv_cache: torch.Tensor,
|
||||||
|
dst_kv_cache: torch.Tensor,
|
||||||
|
src_to_dst: torch.Tensor,
|
||||||
|
) -> None:
|
||||||
|
raise RuntimeError("swap_blocks is not used for the TPU backend.")
|
||||||
|
|
||||||
|
@torch.compile(backend="openxla")
|
||||||
|
@staticmethod
|
||||||
|
def copy_blocks(
|
||||||
|
kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
|
||||||
|
src_to_dists: Tuple[torch.Tensor, torch.Tensor],
|
||||||
|
) -> None:
|
||||||
|
src_indices, dst_indices = src_to_dists
|
||||||
|
for k_cache, v_cache in kv_caches:
|
||||||
|
torch.ops.xla.dynamo_set_buffer_donor_(k_cache, True)
|
||||||
|
k_cache[:, dst_indices] = k_cache[:, src_indices]
|
||||||
|
torch.ops.xla.dynamo_set_buffer_donor_(v_cache, True)
|
||||||
|
v_cache[:, dst_indices] = v_cache[:, src_indices]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PallasMetadata(AttentionMetadata):
|
||||||
|
|
||||||
|
# Currently, input sequences can only contain all prefills
|
||||||
|
# or all decoding.
|
||||||
|
block_tables: Optional[torch.Tensor] = None
|
||||||
|
context_lens: Optional[torch.Tensor] = None
|
||||||
|
effective_query_lens: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def prefill_metadata(self) -> Optional["PallasMetadata"]:
|
||||||
|
if self.num_prefills == 0:
|
||||||
|
return None
|
||||||
|
|
||||||
|
assert self.num_decode_tokens == 0
|
||||||
|
return self
|
||||||
|
|
||||||
|
@property
|
||||||
|
def decode_metadata(self) -> Optional["PallasMetadata"]:
|
||||||
|
if self.num_decode_tokens == 0:
|
||||||
|
return None
|
||||||
|
|
||||||
|
assert self.num_prefills == 0
|
||||||
|
assert self.num_prefill_tokens == 0
|
||||||
|
assert self.block_tables is not None
|
||||||
|
assert self.context_lens is not None
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
class PallasAttentionBackendImpl(AttentionImpl):
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
num_heads: int,
|
||||||
|
head_size: int,
|
||||||
|
scale: float,
|
||||||
|
num_kv_heads: int,
|
||||||
|
alibi_slopes: Optional[List[float]],
|
||||||
|
sliding_window: Optional[int],
|
||||||
|
kv_cache_dtype: str,
|
||||||
|
blocksparse_params: Optional[Dict[str, Any]] = None,
|
||||||
|
logits_soft_cap: Optional[float] = None,
|
||||||
|
attn_type: str = AttentionType.DECODER,
|
||||||
|
kv_sharing_target_layer_name: Optional[str] = None,
|
||||||
|
use_irope: bool = False,
|
||||||
|
) -> None:
|
||||||
|
if kv_sharing_target_layer_name is not None:
|
||||||
|
raise NotImplementedError("KV sharing is not supported in V0.")
|
||||||
|
if use_irope:
|
||||||
|
logger.warning_once(
|
||||||
|
"Using irope in Pallas is not supported yet, it will fall back "
|
||||||
|
"to global attention for long context.")
|
||||||
|
self.num_heads = num_heads
|
||||||
|
self.head_size = head_size
|
||||||
|
self.scale = float(scale)
|
||||||
|
self.num_kv_heads = num_kv_heads
|
||||||
|
|
||||||
|
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||||
|
self.logits_soft_cap = logits_soft_cap
|
||||||
|
if head_size % 128 != 0:
|
||||||
|
raise NotImplementedError(
|
||||||
|
f"Head size must be a multiple of 128, found {head_size}.")
|
||||||
|
if alibi_slopes is not None:
|
||||||
|
raise NotImplementedError("Alibi slopes is not supported.")
|
||||||
|
if sliding_window is not None:
|
||||||
|
raise NotImplementedError("Sliding window is not supported.")
|
||||||
|
if is_quantized_kv_cache(kv_cache_dtype):
|
||||||
|
raise NotImplementedError("FP8 KV cache dtype is not supported.")
|
||||||
|
if blocksparse_params is not None:
|
||||||
|
raise NotImplementedError("Blocksparse is not supported.")
|
||||||
|
|
||||||
|
if torch_xla.tpu.version() < 4:
|
||||||
|
raise NotImplementedError("TPU version must be 4 or higher.")
|
||||||
|
|
||||||
|
self.megacore_mode = None
|
||||||
|
tpu_env = torch_xla.tpu.get_tpu_env()
|
||||||
|
tpu_type = (tpu_env.get("ACCELERATOR_TYPE", None)
|
||||||
|
or tpu_env.get("TYPE", None)
|
||||||
|
or tpu_env.get("TPU_ACCELERATOR_TYPE", None))
|
||||||
|
assert tpu_type is not None
|
||||||
|
tpu_type = tpu_type.lower()
|
||||||
|
|
||||||
|
if (("lite" not in tpu_type) and ("v6" not in tpu_type)):
|
||||||
|
if self.num_kv_heads % 2 == 0:
|
||||||
|
self.megacore_mode = "kv_head"
|
||||||
|
else:
|
||||||
|
# NOTE(woosuk): If the batch size is not a multiple of 2, the
|
||||||
|
# megacore mode will be None.
|
||||||
|
self.megacore_mode = "batch"
|
||||||
|
|
||||||
|
if attn_type != AttentionType.DECODER:
|
||||||
|
raise NotImplementedError("Encoder self-attention and "
|
||||||
|
"encoder/decoder cross-attention "
|
||||||
|
"are not implemented for "
|
||||||
|
"PallasAttentionBackendImpl")
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
layer: AttentionLayer,
|
||||||
|
query: torch.Tensor,
|
||||||
|
key: torch.Tensor,
|
||||||
|
value: torch.Tensor,
|
||||||
|
kv_cache: Tuple[torch.Tensor, torch.Tensor],
|
||||||
|
attn_metadata: PallasMetadata,
|
||||||
|
output: Optional[torch.Tensor] = None,
|
||||||
|
output_scale: Optional[torch.Tensor] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Forward pass with Pallas attention.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: shape = [batch_size, seq_len, num_heads * head_size]
|
||||||
|
key: shape = [batch_size, seq_len, num_kv_heads * head_size]
|
||||||
|
value: shape = [batch_size, seq_len, num_kv_heads * head_size]
|
||||||
|
kv_cache[0] = [num_kv_heads, num_blocks, block_size, head_size]
|
||||||
|
kv_cache[1] = [num_kv_heads, num_blocks, block_size, head_size]
|
||||||
|
NOTE: kv_cache[0] and kv_cache[1] will be an empty tensor
|
||||||
|
with shape [0] for profiling run.
|
||||||
|
attn_metadata: Metadata for attention.
|
||||||
|
Returns:
|
||||||
|
shape = [batch_size, seq_len, num_heads * head_size]
|
||||||
|
"""
|
||||||
|
if output_scale is not None:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"fused output quantization is not yet supported"
|
||||||
|
" for PallasAttentionImpl")
|
||||||
|
|
||||||
|
assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
|
||||||
|
batch_size, seq_len, hidden_size = query.shape
|
||||||
|
query = query.view(batch_size, seq_len, self.num_heads, self.head_size)
|
||||||
|
key = key.view(batch_size, seq_len, self.num_kv_heads, self.head_size)
|
||||||
|
value = value.view(batch_size, seq_len, self.num_kv_heads,
|
||||||
|
self.head_size)
|
||||||
|
|
||||||
|
if kv_cache[0].numel() > 0:
|
||||||
|
slot_mapping = attn_metadata.slot_mapping
|
||||||
|
key_cache, value_cache = kv_cache
|
||||||
|
write_to_kv_cache(key, value, key_cache, value_cache, slot_mapping)
|
||||||
|
|
||||||
|
query = query * self.scale
|
||||||
|
if attn_metadata.num_prefills > 0:
|
||||||
|
if attn_metadata.block_tables is None:
|
||||||
|
# Prefill without paged KV cache.
|
||||||
|
assert seq_len % 16 == 0, (
|
||||||
|
"Pallas FlashAttention kernel requires seq_len to be a "
|
||||||
|
f"multiple of 16 but got {seq_len}")
|
||||||
|
|
||||||
|
# Handle GQA/MQA.
|
||||||
|
if self.num_kv_heads != self.num_heads:
|
||||||
|
key = key.repeat_interleave(self.num_queries_per_kv,
|
||||||
|
dim=-2)
|
||||||
|
key = key.view(batch_size, seq_len, self.num_heads,
|
||||||
|
self.head_size)
|
||||||
|
value = value.repeat_interleave(self.num_queries_per_kv,
|
||||||
|
dim=-2)
|
||||||
|
value = value.view(batch_size, seq_len, self.num_heads,
|
||||||
|
self.head_size)
|
||||||
|
# FlashAttention kernel requires the input shape to be
|
||||||
|
# [batch_size, num_heads, seq_len, d_model]
|
||||||
|
# while the input is [batch_size, seq_len, num_heads, d_model].
|
||||||
|
# Permute the input to match the required format.
|
||||||
|
output = torch.ops.xla.flash_attention(
|
||||||
|
query.permute(0, 2, 1, 3),
|
||||||
|
key.permute(0, 2, 1, 3),
|
||||||
|
value.permute(0, 2, 1, 3),
|
||||||
|
True,
|
||||||
|
)
|
||||||
|
output = output.permute(0, 2, 1, 3)
|
||||||
|
else:
|
||||||
|
# Prefill with paged KV cache.
|
||||||
|
# TODO(woosuk): Tune the below knobs.
|
||||||
|
num_kv_pages_per_compute_block = 16
|
||||||
|
num_queries_per_compute_block = 16
|
||||||
|
assert seq_len % num_queries_per_compute_block == 0
|
||||||
|
output = torch.ops.xla.multi_queries_paged_attention(
|
||||||
|
query,
|
||||||
|
key_cache,
|
||||||
|
value_cache,
|
||||||
|
attn_metadata.context_lens,
|
||||||
|
attn_metadata.block_tables,
|
||||||
|
attn_metadata.effective_query_lens,
|
||||||
|
num_kv_pages_per_compute_block,
|
||||||
|
num_queries_per_compute_block,
|
||||||
|
use_kernel=True,
|
||||||
|
attn_logits_soft_cap=self.logits_soft_cap,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Decoding run.
|
||||||
|
assert kv_cache[0].numel() > 0
|
||||||
|
query = query.squeeze(dim=1)
|
||||||
|
pages_per_compute_block = 16 # TODO(woosuk): Tune this value.
|
||||||
|
|
||||||
|
assert attn_metadata.block_tables is not None
|
||||||
|
assert attn_metadata.context_lens is not None
|
||||||
|
# NOTE(woosuk): The PagedAttention Pallas kernel stores the entire
|
||||||
|
# block table in SMEM. Therefore, if the block table is too large,
|
||||||
|
# the kernel compilation will fail. To avoid this, we split the
|
||||||
|
# batch dimension into smaller chunks and run the kernel multiple
|
||||||
|
# times.
|
||||||
|
MAX_SMEM_USAGE = 512 * 1024
|
||||||
|
size_per_seq = 4 * attn_metadata.block_tables.shape[1]
|
||||||
|
max_num_seq = MAX_SMEM_USAGE // size_per_seq
|
||||||
|
|
||||||
|
if batch_size <= max_num_seq:
|
||||||
|
output = paged_attention(
|
||||||
|
query,
|
||||||
|
key_cache,
|
||||||
|
value_cache,
|
||||||
|
attn_metadata.context_lens,
|
||||||
|
attn_metadata.block_tables,
|
||||||
|
pages_per_compute_block,
|
||||||
|
self.megacore_mode,
|
||||||
|
attn_logits_soft_cap=self.logits_soft_cap,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
chunk_size = max_num_seq
|
||||||
|
# Make sure the chunk size is a multiple of 2.
|
||||||
|
chunk_size = chunk_size // 2 * 2
|
||||||
|
num_chunks = (batch_size + chunk_size - 1) // chunk_size
|
||||||
|
|
||||||
|
output = torch.empty_like(query)
|
||||||
|
for chunk_idx in range(num_chunks):
|
||||||
|
chunk_start = chunk_idx * chunk_size
|
||||||
|
chunk_end = chunk_start + chunk_size
|
||||||
|
# NOTE(woosuk): We skip this line because it causes Dynamo
|
||||||
|
# compilation error. Instead, we rely on the slice operation
|
||||||
|
# to handle the out-of-bound case.
|
||||||
|
# chunk_end = min(chunk_end, batch_size)
|
||||||
|
chunk_output = paged_attention(
|
||||||
|
query[chunk_start:chunk_end],
|
||||||
|
key_cache,
|
||||||
|
value_cache,
|
||||||
|
attn_metadata.context_lens[chunk_start:chunk_end],
|
||||||
|
attn_metadata.block_tables[chunk_start:chunk_end],
|
||||||
|
pages_per_compute_block,
|
||||||
|
self.megacore_mode,
|
||||||
|
attn_logits_soft_cap=self.logits_soft_cap,
|
||||||
|
)
|
||||||
|
output[chunk_start:chunk_end] = chunk_output
|
||||||
|
|
||||||
|
# Reshape the output tensor.
|
||||||
|
return output.reshape(batch_size, seq_len, hidden_size)
|
||||||
|
|
||||||
|
|
||||||
|
def write_to_kv_cache(
|
||||||
|
key: torch.Tensor,
|
||||||
|
value: torch.Tensor,
|
||||||
|
key_cache: torch.Tensor,
|
||||||
|
value_cache: torch.Tensor,
|
||||||
|
slot_mapping: torch.Tensor,
|
||||||
|
) -> None:
|
||||||
|
torch.ops.xla.dynamo_set_buffer_donor_(key_cache, True)
|
||||||
|
torch.ops.xla.dynamo_set_buffer_donor_(value_cache, True)
|
||||||
|
|
||||||
|
key = key.flatten(0, 2)
|
||||||
|
value = value.flatten(0, 2)
|
||||||
|
key_cache = key_cache.flatten(0, 2)
|
||||||
|
value_cache = value_cache.flatten(0, 2)
|
||||||
|
key_cache.index_copy_(0, slot_mapping, key)
|
||||||
|
value_cache.index_copy_(0, slot_mapping, value)
|
||||||
|
|
||||||
|
|
||||||
|
def paged_attention(
|
||||||
|
query: torch.Tensor,
|
||||||
|
key_cache: torch.Tensor,
|
||||||
|
value_cache: torch.Tensor,
|
||||||
|
context_lens: torch.Tensor,
|
||||||
|
block_tables: torch.Tensor,
|
||||||
|
pages_per_compute_block: int,
|
||||||
|
megacore_mode: Optional[str],
|
||||||
|
*,
|
||||||
|
attn_logits_soft_cap: Optional[float],
|
||||||
|
) -> torch.Tensor:
|
||||||
|
batch_size = query.shape[0]
|
||||||
|
if megacore_mode == "batch" and batch_size % 2 != 0:
|
||||||
|
megacore_mode = None
|
||||||
|
else:
|
||||||
|
megacore_mode = megacore_mode
|
||||||
|
|
||||||
|
return torch.ops.xla.paged_attention(
|
||||||
|
query,
|
||||||
|
key_cache,
|
||||||
|
value_cache,
|
||||||
|
context_lens,
|
||||||
|
block_tables,
|
||||||
|
pages_per_compute_block,
|
||||||
|
megacore_mode=megacore_mode,
|
||||||
|
attn_logits_soft_cap=attn_logits_soft_cap,
|
||||||
|
)
|
|
@ -3,24 +3,78 @@
|
||||||
""" Attention layer with torch scaled_dot_product_attention
|
""" Attention layer with torch scaled_dot_product_attention
|
||||||
and PagedAttention."""
|
and PagedAttention."""
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional, Tuple, Type
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.nn.functional import scaled_dot_product_attention
|
from torch.nn.functional import scaled_dot_product_attention
|
||||||
|
|
||||||
# yapf conflicts with isort for this block
|
# yapf conflicts with isort for this block
|
||||||
# yapf: disable
|
# yapf: disable
|
||||||
from vllm.attention.backends.abstract import (AttentionImpl, AttentionLayer,
|
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
||||||
AttentionMetadata, AttentionType,
|
AttentionLayer,
|
||||||
|
AttentionMetadata,
|
||||||
|
AttentionMetadataBuilder,
|
||||||
|
AttentionType,
|
||||||
is_quantized_kv_cache)
|
is_quantized_kv_cache)
|
||||||
# yapf: enable
|
# yapf: enable
|
||||||
|
from vllm.attention.backends.utils import CommonAttentionState
|
||||||
from vllm.attention.ops.ipex_attn import PagedAttention, _use_ipex
|
from vllm.attention.ops.ipex_attn import PagedAttention, _use_ipex
|
||||||
from vllm.attention.ops.paged_attn import PagedAttentionMetadata
|
from vllm.attention.ops.paged_attn import PagedAttentionMetadata
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
|
from vllm.utils import make_tensor_with_pad
|
||||||
|
from vllm.worker.cpu_model_runner import ModelInputForCPUBuilder
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class TorchSDPABackend(AttentionBackend):
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_name() -> str:
|
||||||
|
return "TORCH_SDPA"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_impl_cls() -> Type["TorchSDPABackendImpl"]:
|
||||||
|
return TorchSDPABackendImpl
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_metadata_cls() -> Type["AttentionMetadata"]:
|
||||||
|
return TorchSDPAMetadata
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_state_cls() -> Type["CommonAttentionState"]:
|
||||||
|
return CommonAttentionState
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_builder_cls() -> Type["TorchSDPAMetadataBuilder"]:
|
||||||
|
return TorchSDPAMetadataBuilder
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_kv_cache_shape(
|
||||||
|
num_blocks: int,
|
||||||
|
block_size: int,
|
||||||
|
num_kv_heads: int,
|
||||||
|
head_size: int,
|
||||||
|
) -> Tuple[int, ...]:
|
||||||
|
return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
|
||||||
|
num_kv_heads, head_size)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def swap_blocks(
|
||||||
|
src_kv_cache: torch.Tensor,
|
||||||
|
dst_kv_cache: torch.Tensor,
|
||||||
|
src_to_dst: torch.Tensor,
|
||||||
|
) -> None:
|
||||||
|
raise NotImplementedError("Swap is not supported in TorchSDPABackend.")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def copy_blocks(
|
||||||
|
kv_caches: List[torch.Tensor],
|
||||||
|
src_to_dists: torch.Tensor,
|
||||||
|
) -> None:
|
||||||
|
PagedAttention.copy_blocks(kv_caches, src_to_dists)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
|
class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
|
||||||
"""Metadata for TorchSDPABackend.
|
"""Metadata for TorchSDPABackend.
|
||||||
|
@ -233,6 +287,113 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata):
|
||||||
raise AttributeError(f"Invalid attention type {str(attn_type)}")
|
raise AttributeError(f"Invalid attention type {str(attn_type)}")
|
||||||
|
|
||||||
|
|
||||||
|
class TorchSDPAMetadataBuilder(AttentionMetadataBuilder[TorchSDPAMetadata]):
|
||||||
|
|
||||||
|
def __init__(self, input_builder: ModelInputForCPUBuilder) -> None:
|
||||||
|
self.chunked_prefill = input_builder.chunked_prefill
|
||||||
|
self.input_builder = input_builder
|
||||||
|
|
||||||
|
def prepare(self):
|
||||||
|
self.input_data = self.input_builder.input_data
|
||||||
|
|
||||||
|
def build(self, seq_lens: List[int], query_lens: List[int],
|
||||||
|
cuda_graph_pad_size: int, batch_size: int) -> TorchSDPAMetadata:
|
||||||
|
input_data = self.input_data
|
||||||
|
prefill_seq_lens = seq_lens[0:input_data.num_prefills]
|
||||||
|
prefill_query_lens = query_lens[0:input_data.num_prefills]
|
||||||
|
slot_mapping = torch.tensor(input_data.slot_mapping,
|
||||||
|
dtype=torch.long,
|
||||||
|
device="cpu")
|
||||||
|
|
||||||
|
# For chunked-prefill
|
||||||
|
if self.chunked_prefill and input_data.num_prefill_tokens != 0:
|
||||||
|
prefill_block_tables = make_tensor_with_pad(
|
||||||
|
self.input_data.prefill_block_tables,
|
||||||
|
pad=0,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device="cpu",
|
||||||
|
)
|
||||||
|
query_lens_tensor = torch.tensor(prefill_query_lens,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device="cpu")
|
||||||
|
kv_lens_tensor = torch.tensor(prefill_seq_lens,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device="cpu")
|
||||||
|
query_start_loc = torch.zeros(input_data.num_prefills + 1,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device="cpu")
|
||||||
|
kv_start_loc = torch.zeros(input_data.num_prefills + 1,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device="cpu")
|
||||||
|
torch.cumsum(query_lens_tensor,
|
||||||
|
dim=0,
|
||||||
|
dtype=torch.int32,
|
||||||
|
out=query_start_loc[1:])
|
||||||
|
torch.cumsum(kv_lens_tensor,
|
||||||
|
dim=0,
|
||||||
|
dtype=torch.int32,
|
||||||
|
out=kv_start_loc[1:])
|
||||||
|
max_query_len = max(prefill_query_lens)
|
||||||
|
max_kv_len = max(prefill_seq_lens)
|
||||||
|
else:
|
||||||
|
prefill_block_tables = None
|
||||||
|
query_start_loc = None
|
||||||
|
kv_start_loc = None
|
||||||
|
max_query_len = None
|
||||||
|
max_kv_len = None
|
||||||
|
|
||||||
|
# For paged attention
|
||||||
|
if input_data.num_decode_tokens != 0:
|
||||||
|
seq_lens_tensor = torch.tensor(
|
||||||
|
input_data.seq_lens[input_data.num_prefills:],
|
||||||
|
dtype=torch.int32,
|
||||||
|
device="cpu",
|
||||||
|
)
|
||||||
|
block_tables = make_tensor_with_pad(
|
||||||
|
self.input_data.decode_block_tables,
|
||||||
|
pad=0,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device="cpu",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
block_tables = torch.tensor([])
|
||||||
|
seq_lens_tensor = torch.tensor(
|
||||||
|
input_data.seq_lens[:input_data.num_prefills],
|
||||||
|
dtype=torch.int32,
|
||||||
|
device="cpu",
|
||||||
|
)
|
||||||
|
|
||||||
|
# For multi-modal models
|
||||||
|
placeholder_index_maps = None
|
||||||
|
if len(input_data.multi_modal_inputs_list) != 0:
|
||||||
|
placeholder_index_maps = {
|
||||||
|
modality: placeholder_map.index_map()
|
||||||
|
for modality, placeholder_map in
|
||||||
|
input_data.multi_modal_placeholder_maps.items()
|
||||||
|
}
|
||||||
|
|
||||||
|
attn_metadata = TorchSDPAMetadata(
|
||||||
|
chunked_prefill=self.chunked_prefill,
|
||||||
|
seq_lens=prefill_seq_lens,
|
||||||
|
seq_lens_tensor=seq_lens_tensor,
|
||||||
|
max_query_len=max_query_len,
|
||||||
|
max_kv_len=max_kv_len,
|
||||||
|
prefill_query_start_loc=query_start_loc,
|
||||||
|
kv_start_loc=kv_start_loc,
|
||||||
|
max_decode_seq_len=input_data.max_decode_seq_len,
|
||||||
|
num_prefills=input_data.num_prefills,
|
||||||
|
num_prefill_tokens=input_data.num_prefill_tokens,
|
||||||
|
num_decode_tokens=input_data.num_decode_tokens,
|
||||||
|
block_tables=block_tables,
|
||||||
|
prefill_block_tables=prefill_block_tables,
|
||||||
|
slot_mapping=slot_mapping,
|
||||||
|
multi_modal_placeholder_index_maps=placeholder_index_maps,
|
||||||
|
enable_kv_scales_calculation=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
return attn_metadata
|
||||||
|
|
||||||
|
|
||||||
class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
|
class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|
|
@ -64,11 +64,13 @@ class CpuPlatform(Platform):
|
||||||
if selected_backend and selected_backend != _Backend.TORCH_SDPA:
|
if selected_backend and selected_backend != _Backend.TORCH_SDPA:
|
||||||
logger.info("Cannot use %s backend on CPU.", selected_backend)
|
logger.info("Cannot use %s backend on CPU.", selected_backend)
|
||||||
if use_mla:
|
if use_mla:
|
||||||
raise NotImplementedError("MLA is not supported on CPU.")
|
logger.info("Using CPU MLA backend.")
|
||||||
|
return "vllm.attention.backends.cpu_mla.CPUMLABackend"
|
||||||
logger.info("Using Torch SDPA backend.")
|
logger.info("Using Torch SDPA backend.")
|
||||||
if not use_v1:
|
if use_v1:
|
||||||
raise ValueError("CPU backend only supports V1.")
|
return "vllm.v1.attention.backends.cpu_attn.TorchSDPABackend"
|
||||||
return "vllm.v1.attention.backends.cpu_attn.TorchSDPABackend"
|
else:
|
||||||
|
return "vllm.attention.backends.torch_sdpa.TorchSDPABackend"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_device_total_memory(cls, device_id: int = 0) -> int:
|
def get_device_total_memory(cls, device_id: int = 0) -> int:
|
||||||
|
@ -145,14 +147,26 @@ class CpuPlatform(Platform):
|
||||||
parallel_config.distributed_executor_backend)
|
parallel_config.distributed_executor_backend)
|
||||||
parallel_config.distributed_executor_backend = "mp"
|
parallel_config.distributed_executor_backend = "mp"
|
||||||
if parallel_config.worker_cls == "auto":
|
if parallel_config.worker_cls == "auto":
|
||||||
parallel_config.worker_cls = "vllm.v1.worker.cpu_worker.CPUWorker"
|
if vllm_config.speculative_config:
|
||||||
|
parallel_config.worker_cls = \
|
||||||
|
"vllm.spec_decode.spec_decode_worker.create_spec_worker"
|
||||||
|
parallel_config.sd_worker_cls = \
|
||||||
|
"vllm.worker.cpu_worker.CPUWorker"
|
||||||
|
else:
|
||||||
|
if envs.VLLM_USE_V1:
|
||||||
|
parallel_config.worker_cls = \
|
||||||
|
"vllm.v1.worker.cpu_worker.CPUWorker"
|
||||||
|
else:
|
||||||
|
parallel_config.worker_cls = \
|
||||||
|
"vllm.worker.cpu_worker.CPUWorker"
|
||||||
|
|
||||||
# Note: workaround for v1 gpu_model_runner
|
# Note: workaround for v1 gpu_model_runner
|
||||||
from vllm.config import CompilationLevel
|
from vllm.config import CompilationLevel
|
||||||
vllm_config.compilation_config.cudagraph_capture_sizes = []
|
vllm_config.compilation_config.cudagraph_capture_sizes = []
|
||||||
|
|
||||||
compilation_config = vllm_config.compilation_config
|
compilation_config = vllm_config.compilation_config
|
||||||
if vllm_config.compilation_config.level == CompilationLevel.PIECEWISE:
|
if (envs.VLLM_USE_V1 and vllm_config.compilation_config.level
|
||||||
|
== CompilationLevel.PIECEWISE):
|
||||||
|
|
||||||
# Note: vLLM V1 is using PIECEWISE level compilation, which will
|
# Note: vLLM V1 is using PIECEWISE level compilation, which will
|
||||||
# take time to compile kernels just-in-time with the inductor
|
# take time to compile kernels just-in-time with the inductor
|
||||||
|
|
|
@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Optional, Union, cast
|
||||||
import torch
|
import torch
|
||||||
from tpu_info import device
|
from tpu_info import device
|
||||||
|
|
||||||
|
import vllm.envs as envs
|
||||||
from vllm.inputs import ProcessorInputs, PromptType
|
from vllm.inputs import ProcessorInputs, PromptType
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.sampling_params import SamplingParams, SamplingType
|
from vllm.sampling_params import SamplingParams, SamplingType
|
||||||
|
@ -49,10 +50,12 @@ class TpuPlatform(Platform):
|
||||||
and selected_backend != _Backend.PALLAS_VLLM_V1):
|
and selected_backend != _Backend.PALLAS_VLLM_V1):
|
||||||
logger.info("Cannot use %s backend on TPU.", selected_backend)
|
logger.info("Cannot use %s backend on TPU.", selected_backend)
|
||||||
|
|
||||||
if not use_v1:
|
if use_v1:
|
||||||
raise ValueError("TPU backend only supports V1.")
|
logger.info("Using Pallas V1 backend.")
|
||||||
logger.info("Using Pallas V1 backend.")
|
return "vllm.v1.attention.backends.pallas.PallasAttentionBackend"
|
||||||
return "vllm.v1.attention.backends.pallas.PallasAttentionBackend"
|
else:
|
||||||
|
logger.info("Using Pallas backend.")
|
||||||
|
return "vllm.attention.backends.pallas.PallasAttentionBackend"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_device_name(cls, device_id: int = 0) -> str:
|
def get_device_name(cls, device_id: int = 0) -> str:
|
||||||
|
@ -65,7 +68,7 @@ class TpuPlatform(Platform):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
|
def is_async_output_supported(cls, enforce_eager: Optional[bool]) -> bool:
|
||||||
return False
|
return not envs.VLLM_USE_V1
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_punica_wrapper(cls) -> str:
|
def get_punica_wrapper(cls) -> str:
|
||||||
|
@ -114,19 +117,31 @@ class TpuPlatform(Platform):
|
||||||
"Using bfloat16 instead.", vllm_config.model_config.dtype)
|
"Using bfloat16 instead.", vllm_config.model_config.dtype)
|
||||||
vllm_config.model_config.dtype = torch.bfloat16
|
vllm_config.model_config.dtype = torch.bfloat16
|
||||||
|
|
||||||
from vllm.v1.attention.backends.pallas import PallasAttentionBackend
|
if envs.VLLM_USE_V1:
|
||||||
cache_config.block_size = PallasAttentionBackend.get_page_size(
|
from vllm.v1.attention.backends.pallas import (
|
||||||
vllm_config) # type: ignore[assignment]
|
PallasAttentionBackend)
|
||||||
|
cache_config.block_size = PallasAttentionBackend.get_page_size(
|
||||||
|
vllm_config) # type: ignore[assignment]
|
||||||
|
|
||||||
parallel_config = vllm_config.parallel_config
|
parallel_config = vllm_config.parallel_config
|
||||||
scheduler_config = vllm_config.scheduler_config
|
scheduler_config = vllm_config.scheduler_config
|
||||||
if parallel_config.worker_cls == "auto":
|
if parallel_config.worker_cls == "auto":
|
||||||
if scheduler_config.is_multi_step:
|
if scheduler_config.is_multi_step:
|
||||||
raise NotImplementedError(
|
if envs.VLLM_USE_V1:
|
||||||
"Multi-step scheduling is not supported (and not "
|
raise NotImplementedError(
|
||||||
"needed) on vLLM V1. Please launch without "
|
"Multi-step scheduling is not supported (and not "
|
||||||
"--num-scheduler-steps.")
|
"needed) on vLLM V1. Please launch without "
|
||||||
parallel_config.worker_cls = "vllm.v1.worker.tpu_worker.TPUWorker"
|
"--num-scheduler-steps.")
|
||||||
|
else:
|
||||||
|
parallel_config.worker_cls = \
|
||||||
|
"vllm.worker.multi_step_tpu_worker.MultiStepTPUWorker"
|
||||||
|
else:
|
||||||
|
if envs.VLLM_USE_V1:
|
||||||
|
parallel_config.worker_cls = \
|
||||||
|
"vllm.v1.worker.tpu_worker.TPUWorker"
|
||||||
|
else:
|
||||||
|
parallel_config.worker_cls = \
|
||||||
|
"vllm.worker.tpu_worker.TPUWorker"
|
||||||
|
|
||||||
assert not vllm_config.speculative_config, (
|
assert not vllm_config.speculative_config, (
|
||||||
"Speculative decoding is not yet supported for TPU backend")
|
"Speculative decoding is not yet supported for TPU backend")
|
||||||
|
@ -174,9 +189,13 @@ class TpuPlatform(Platform):
|
||||||
processed_inputs: ProcessorInputs,
|
processed_inputs: ProcessorInputs,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Raises if this request is unsupported on this platform"""
|
"""Raises if this request is unsupported on this platform"""
|
||||||
if (isinstance(params, SamplingParams)
|
if isinstance(params, SamplingParams):
|
||||||
and params.sampling_type == SamplingType.RANDOM_SEED):
|
if params.guided_decoding is not None and not envs.VLLM_USE_V1:
|
||||||
raise ValueError("Torch XLA does not support per-request seed.")
|
raise ValueError("Structured output is not supported on "
|
||||||
|
f"{cls.device_name} V0.")
|
||||||
|
if params.sampling_type == SamplingType.RANDOM_SEED:
|
||||||
|
raise ValueError(
|
||||||
|
"Torch XLA does not support per-request seed.")
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -39,10 +39,12 @@ class XPUPlatform(Platform):
|
||||||
if selected_backend != _Backend.IPEX:
|
if selected_backend != _Backend.IPEX:
|
||||||
logger.info("Cannot use %s backend on XPU.", selected_backend)
|
logger.info("Cannot use %s backend on XPU.", selected_backend)
|
||||||
use_v1 = envs.VLLM_USE_V1
|
use_v1 = envs.VLLM_USE_V1
|
||||||
if not use_v1:
|
if use_v1:
|
||||||
raise ValueError("XPU backend only supports V1.")
|
logger.info("Using Flash Attention backend on V1 engine.")
|
||||||
logger.info("Using Flash Attention backend on V1 engine.")
|
return "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
|
||||||
return "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
|
else:
|
||||||
|
logger.info("Using IPEX attention backend.")
|
||||||
|
return "vllm.attention.backends.ipex_attn.IpexAttnBackend"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_device_capability(
|
def get_device_capability(
|
||||||
|
@ -75,7 +77,10 @@ class XPUPlatform(Platform):
|
||||||
cache_config = vllm_config.cache_config
|
cache_config = vllm_config.cache_config
|
||||||
# in V1(or with ipex chunked prefill) block_size is 64
|
# in V1(or with ipex chunked prefill) block_size is 64
|
||||||
if cache_config and cache_config.block_size is None:
|
if cache_config and cache_config.block_size is None:
|
||||||
cache_config.block_size = 64
|
if envs.VLLM_USE_V1:
|
||||||
|
cache_config.block_size = 64
|
||||||
|
else:
|
||||||
|
cache_config.block_size = 16
|
||||||
|
|
||||||
# Instances created using VllmConfig() typically have model_config as
|
# Instances created using VllmConfig() typically have model_config as
|
||||||
# None by default. The modification involves adding a check to prevent
|
# None by default. The modification involves adding a check to prevent
|
||||||
|
@ -101,7 +106,11 @@ class XPUPlatform(Platform):
|
||||||
|
|
||||||
# check and update parallel config
|
# check and update parallel config
|
||||||
parallel_config = vllm_config.parallel_config
|
parallel_config = vllm_config.parallel_config
|
||||||
parallel_config.worker_cls = "vllm.v1.worker.xpu_worker.XPUWorker"
|
if envs.VLLM_USE_V1:
|
||||||
|
parallel_config.worker_cls =\
|
||||||
|
"vllm.v1.worker.xpu_worker.XPUWorker"
|
||||||
|
else:
|
||||||
|
parallel_config.worker_cls = "vllm.worker.xpu_worker.XPUWorker"
|
||||||
|
|
||||||
if parallel_config.distributed_executor_backend is None:
|
if parallel_config.distributed_executor_backend is None:
|
||||||
if parallel_config.world_size > 1:
|
if parallel_config.world_size > 1:
|
||||||
|
|
|
@ -0,0 +1,326 @@
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import dataclasses
|
||||||
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, cast
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.attention import AttentionMetadata
|
||||||
|
from vllm.forward_context import set_forward_context
|
||||||
|
from vllm.model_executor import SamplingMetadata
|
||||||
|
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||||
|
from vllm.multimodal import MultiModalKwargs
|
||||||
|
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
|
||||||
|
from vllm.utils import make_tensor_with_pad
|
||||||
|
from vllm.worker.cpu_model_runner import (CPUModelRunnerBase,
|
||||||
|
ModelInputForCPUBuilder,
|
||||||
|
ModelInputForCPUWithSamplingMetadata)
|
||||||
|
from vllm.worker.model_runner_base import (
|
||||||
|
_add_attn_metadata_broadcastable_dict,
|
||||||
|
_add_sampling_metadata_broadcastable_dict)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from vllm.attention.backends.abstract import AttentionBackend
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass(frozen=True)
|
||||||
|
class EncoderDecoderModelInputForCPU(ModelInputForCPUWithSamplingMetadata):
|
||||||
|
"""
|
||||||
|
Used by the EncoderDecoderModelRunner.
|
||||||
|
"""
|
||||||
|
encoder_input_tokens: Optional[torch.Tensor] = None
|
||||||
|
encoder_input_positions: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
|
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
|
||||||
|
tensor_dict = {
|
||||||
|
"input_tokens": self.input_tokens,
|
||||||
|
"input_positions": self.input_positions,
|
||||||
|
"encoder_input_tokens": self.encoder_input_tokens,
|
||||||
|
"encoder_input_positions": self.encoder_input_positions,
|
||||||
|
"multi_modal_kwargs": self.multi_modal_kwargs,
|
||||||
|
}
|
||||||
|
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
|
||||||
|
_add_sampling_metadata_broadcastable_dict(tensor_dict,
|
||||||
|
self.sampling_metadata)
|
||||||
|
return tensor_dict
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_broadcasted_tensor_dict(
|
||||||
|
cls,
|
||||||
|
tensor_dict: Dict[str, Any],
|
||||||
|
attn_backend: Optional["AttentionBackend"] = None,
|
||||||
|
) -> "EncoderDecoderModelInputForCPU":
|
||||||
|
return cast(
|
||||||
|
EncoderDecoderModelInputForCPU,
|
||||||
|
super().from_broadcasted_tensor_dict(tensor_dict, attn_backend))
|
||||||
|
|
||||||
|
|
||||||
|
class CPUEncoderDecoderModelRunner(
|
||||||
|
CPUModelRunnerBase[EncoderDecoderModelInputForCPU]):
|
||||||
|
_model_input_cls: Type[EncoderDecoderModelInputForCPU] = (
|
||||||
|
EncoderDecoderModelInputForCPU)
|
||||||
|
_builder_cls: Type[ModelInputForCPUBuilder] = ModelInputForCPUBuilder
|
||||||
|
|
||||||
|
def _list_to_int32_tensor(
|
||||||
|
self,
|
||||||
|
_list: List[int],
|
||||||
|
) -> torch.Tensor:
|
||||||
|
return torch.tensor(_list, dtype=torch.int32, device=self.device)
|
||||||
|
|
||||||
|
def _list_to_long_tensor(
|
||||||
|
self,
|
||||||
|
_list: List[int],
|
||||||
|
) -> torch.Tensor:
|
||||||
|
return torch.tensor(_list, dtype=torch.long, device=self.device)
|
||||||
|
|
||||||
|
def _empty_int32_tensor(self) -> torch.Tensor:
|
||||||
|
return self._list_to_int32_tensor([])
|
||||||
|
|
||||||
|
def _empty_long_tensor(self) -> torch.Tensor:
|
||||||
|
return self._list_to_long_tensor([])
|
||||||
|
|
||||||
|
def make_model_input_from_broadcasted_tensor_dict(
|
||||||
|
self, tensor_dict: Dict[str,
|
||||||
|
Any]) -> EncoderDecoderModelInputForCPU:
|
||||||
|
return EncoderDecoderModelInputForCPU.from_broadcasted_tensor_dict(
|
||||||
|
tensor_dict,
|
||||||
|
attn_backend=self.attn_backend,
|
||||||
|
)
|
||||||
|
|
||||||
|
def prepare_model_input(
|
||||||
|
self,
|
||||||
|
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||||
|
virtual_engine: int = 0,
|
||||||
|
finished_requests_ids: Optional[List[str]] = None
|
||||||
|
) -> EncoderDecoderModelInputForCPU:
|
||||||
|
model_input = self._prepare_model_input_tensors(
|
||||||
|
seq_group_metadata_list, finished_requests_ids)
|
||||||
|
(
|
||||||
|
attn_metadata,
|
||||||
|
encoder_input_tokens_tensor,
|
||||||
|
encoder_input_positions_tensor,
|
||||||
|
) = self._prepare_encoder_model_input_tensors(seq_group_metadata_list,
|
||||||
|
model_input)
|
||||||
|
# Sampling metadata is only required for the final pp group
|
||||||
|
generators = self.get_generators(finished_requests_ids)
|
||||||
|
sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list,
|
||||||
|
model_input.seq_lens,
|
||||||
|
model_input.query_lens,
|
||||||
|
self.device,
|
||||||
|
pin_memory=False,
|
||||||
|
generators=generators)
|
||||||
|
return dataclasses.replace(
|
||||||
|
model_input,
|
||||||
|
sampling_metadata=sampling_metadata,
|
||||||
|
attn_metadata=attn_metadata,
|
||||||
|
encoder_input_tokens=encoder_input_tokens_tensor,
|
||||||
|
encoder_input_positions=encoder_input_positions_tensor,
|
||||||
|
virtual_engine=virtual_engine,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _prepare_encoder_model_input_tensors(
|
||||||
|
self,
|
||||||
|
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||||
|
model_input: EncoderDecoderModelInputForCPU,
|
||||||
|
) -> Tuple[AttentionMetadata, Optional[torch.Tensor],
|
||||||
|
Optional[torch.Tensor]]:
|
||||||
|
"""Helper method to prepare the encoder- and cross-attn-related
|
||||||
|
model inputs based on a given sequence group. These additional inputs
|
||||||
|
are used to augment an already-computed `EncoderDecoderModelInput`
|
||||||
|
data structure which already has decoder-related model inputs
|
||||||
|
populated.
|
||||||
|
|
||||||
|
Sets the following attn_metadata fields:
|
||||||
|
* `num_encoder_tokens`
|
||||||
|
* `encoder_seq_lens`
|
||||||
|
* `encoder_seq_lens_tensor`
|
||||||
|
* `max_encoder_seq_len`
|
||||||
|
* `cross_slot_mapping`
|
||||||
|
* `cross_block_tables`
|
||||||
|
|
||||||
|
Constructs a new model inputs data structure, based on
|
||||||
|
(1) the existing fields in the `model_inputs` argument,
|
||||||
|
and (2) the following additional fields which are
|
||||||
|
computed (or in the case of `attn_metadata`, updated)
|
||||||
|
by this function:
|
||||||
|
* attn_metadata
|
||||||
|
* encoder_input_tokens
|
||||||
|
* encoder_input_positions
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
|
||||||
|
* seq_group_metadata_list: list of sequence groups for which to
|
||||||
|
compute inputs
|
||||||
|
* model_inputs: model inputs data structure with decoder-oriented
|
||||||
|
fields already computed.
|
||||||
|
|
||||||
|
Return:
|
||||||
|
|
||||||
|
* Updated model inputs data structure
|
||||||
|
"""
|
||||||
|
|
||||||
|
if len(seq_group_metadata_list) == 0:
|
||||||
|
return (model_input.attn_metadata, None, None)
|
||||||
|
|
||||||
|
# Since we are not supporting chunked prefill either the entire
|
||||||
|
# batch is prefill or it is decode
|
||||||
|
is_prompt = seq_group_metadata_list[0].is_prompt
|
||||||
|
|
||||||
|
# Build encoder inputs
|
||||||
|
encoder_seq_lens: List[int] = []
|
||||||
|
if is_prompt:
|
||||||
|
# Prefill phase.
|
||||||
|
cross_block_tables = self._empty_int32_tensor().view(
|
||||||
|
len(seq_group_metadata_list), -1)
|
||||||
|
|
||||||
|
# Extract input tokens/positions, cross-attention slot-mapping,
|
||||||
|
# & seq len from each sequence group metadata
|
||||||
|
(
|
||||||
|
encoder_input_tokens,
|
||||||
|
encoder_input_positions,
|
||||||
|
cross_slot_mapping,
|
||||||
|
) = (
|
||||||
|
[],
|
||||||
|
[],
|
||||||
|
[],
|
||||||
|
)
|
||||||
|
for seq_group_metadata in seq_group_metadata_list:
|
||||||
|
# Build seq lens
|
||||||
|
seq_len = seq_group_metadata.encoder_seq_data.get_len()
|
||||||
|
token_ids = seq_group_metadata.encoder_seq_data.get_token_ids()
|
||||||
|
encoder_seq_lens.append(seq_len)
|
||||||
|
|
||||||
|
# Build slot mapping
|
||||||
|
for i in range(0, seq_len):
|
||||||
|
block_number = seq_group_metadata.cross_block_table[
|
||||||
|
i // self.block_size]
|
||||||
|
block_offset = i % self.block_size
|
||||||
|
slot = block_number * self.block_size + block_offset
|
||||||
|
cross_slot_mapping.append(slot)
|
||||||
|
|
||||||
|
# Build encoder input tokens
|
||||||
|
encoder_input_tokens.extend(token_ids)
|
||||||
|
encoder_input_positions.extend(list(range(0, seq_len)))
|
||||||
|
|
||||||
|
# Convert tokens/positions & cross-attention
|
||||||
|
# slot-mapping to encoder input tensors
|
||||||
|
encoder_input_tokens_tensor = self._list_to_long_tensor(
|
||||||
|
encoder_input_tokens)
|
||||||
|
encoder_input_positions_tensor = self._list_to_long_tensor(
|
||||||
|
encoder_input_positions)
|
||||||
|
cross_slot_mapping_tensor = self._list_to_long_tensor(
|
||||||
|
cross_slot_mapping)
|
||||||
|
|
||||||
|
else:
|
||||||
|
# Decode phase.
|
||||||
|
encoder_input_tokens_tensor = self._empty_long_tensor()
|
||||||
|
encoder_input_positions_tensor = self._empty_long_tensor()
|
||||||
|
cross_slot_mapping_tensor = self._empty_long_tensor()
|
||||||
|
# Extract cross-attention block tables &
|
||||||
|
# seq len from each sequence group metadata.
|
||||||
|
# Cross-attention block tables are empty
|
||||||
|
# during vLLM memory profiling.
|
||||||
|
cross_block_tables = []
|
||||||
|
for seq_group_metadata in seq_group_metadata_list:
|
||||||
|
for _ in range(len(seq_group_metadata.seq_data)):
|
||||||
|
encoder_seq_lens.append(
|
||||||
|
seq_group_metadata.encoder_seq_data.get_len())
|
||||||
|
cross_block_table = seq_group_metadata.cross_block_table
|
||||||
|
cross_block_tables.append([] if (
|
||||||
|
cross_block_table is None) else cross_block_table)
|
||||||
|
|
||||||
|
max_len_of_block_table = max(
|
||||||
|
len(block_table) for block_table in cross_block_tables)
|
||||||
|
|
||||||
|
cross_block_tables = make_tensor_with_pad(
|
||||||
|
cross_block_tables,
|
||||||
|
max_len=max_len_of_block_table,
|
||||||
|
pad=0,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=self.device,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Compute encoder sequence lengths & encoder
|
||||||
|
# sequence starting offset tensors
|
||||||
|
max_encoder_seq_len = max(encoder_seq_lens, default=0)
|
||||||
|
encoder_seq_lens_tensor = self._list_to_int32_tensor(encoder_seq_lens)
|
||||||
|
encoder_seq_start_loc = torch.zeros(encoder_seq_lens_tensor.shape[0] +
|
||||||
|
1,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=self.device)
|
||||||
|
torch.cumsum(encoder_seq_lens_tensor,
|
||||||
|
dim=0,
|
||||||
|
dtype=encoder_seq_start_loc.dtype,
|
||||||
|
out=encoder_seq_start_loc[1:])
|
||||||
|
|
||||||
|
# Update attention metadata with encoder-oriented attributes
|
||||||
|
attn_metadata = model_input.attn_metadata
|
||||||
|
assert attn_metadata is not None
|
||||||
|
(
|
||||||
|
attn_metadata.num_encoder_tokens,
|
||||||
|
attn_metadata.encoder_seq_lens,
|
||||||
|
attn_metadata.encoder_seq_lens_tensor,
|
||||||
|
attn_metadata.max_encoder_seq_len,
|
||||||
|
attn_metadata.cross_slot_mapping,
|
||||||
|
attn_metadata.cross_block_tables,
|
||||||
|
) = (
|
||||||
|
sum(encoder_seq_lens),
|
||||||
|
encoder_seq_lens,
|
||||||
|
encoder_seq_lens_tensor,
|
||||||
|
max_encoder_seq_len,
|
||||||
|
cross_slot_mapping_tensor,
|
||||||
|
cross_block_tables,
|
||||||
|
)
|
||||||
|
|
||||||
|
return (attn_metadata, encoder_input_tokens_tensor,
|
||||||
|
encoder_input_positions_tensor)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def execute_model(
|
||||||
|
self,
|
||||||
|
model_input: EncoderDecoderModelInputForCPU,
|
||||||
|
kv_caches: List[torch.Tensor],
|
||||||
|
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||||
|
num_steps: int = 1,
|
||||||
|
) -> Optional[List[SamplerOutput]]:
|
||||||
|
if num_steps > 1:
|
||||||
|
raise ValueError(
|
||||||
|
"CPU worker does not support multi-step execution.")
|
||||||
|
|
||||||
|
model_executable = self.model
|
||||||
|
execute_model_kwargs = {
|
||||||
|
"input_ids":
|
||||||
|
model_input.input_tokens,
|
||||||
|
"positions":
|
||||||
|
model_input.input_positions,
|
||||||
|
"encoder_input_ids":
|
||||||
|
model_input.encoder_input_tokens,
|
||||||
|
"encoder_positions":
|
||||||
|
model_input.encoder_input_positions,
|
||||||
|
**MultiModalKwargs.as_kwargs(
|
||||||
|
model_input.multi_modal_kwargs or {},
|
||||||
|
device=self.device,
|
||||||
|
),
|
||||||
|
"intermediate_tensors":
|
||||||
|
intermediate_tensors,
|
||||||
|
}
|
||||||
|
|
||||||
|
with set_forward_context(model_input.attn_metadata, self.vllm_config,
|
||||||
|
model_input.virtual_engine):
|
||||||
|
hidden_states = model_executable(**execute_model_kwargs)
|
||||||
|
|
||||||
|
# Compute the logits.
|
||||||
|
logits = self.model.compute_logits(hidden_states,
|
||||||
|
model_input.sampling_metadata)
|
||||||
|
|
||||||
|
# Only perform sampling in the driver worker.
|
||||||
|
if not self.is_driver_worker:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Sample the next token.
|
||||||
|
output = self.sampler(
|
||||||
|
logits=logits,
|
||||||
|
sampling_metadata=model_input.sampling_metadata,
|
||||||
|
)
|
||||||
|
return [output]
|
|
@ -0,0 +1,671 @@
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import dataclasses
|
||||||
|
import weakref
|
||||||
|
from collections import defaultdict
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Set, Type,
|
||||||
|
TypeVar, Union)
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
from vllm.attention import AttentionMetadata, get_attn_backend
|
||||||
|
from vllm.config import VllmConfig
|
||||||
|
from vllm.forward_context import set_forward_context
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
from vllm.lora.layers import LoRAMapping
|
||||||
|
from vllm.lora.request import LoRARequest
|
||||||
|
from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
|
||||||
|
from vllm.model_executor import SamplingMetadata
|
||||||
|
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
|
||||||
|
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
||||||
|
from vllm.model_executor.model_loader import get_model
|
||||||
|
from vllm.model_executor.models import supports_lora, supports_multimodal
|
||||||
|
from vllm.multimodal import (BatchedTensorInputs, MultiModalKwargs,
|
||||||
|
MultiModalPlaceholderMap)
|
||||||
|
from vllm.sequence import (IntermediateTensors, SequenceData,
|
||||||
|
SequenceGroupMetadata)
|
||||||
|
from vllm.worker.model_runner_base import (
|
||||||
|
ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
|
||||||
|
_add_attn_metadata_broadcastable_dict,
|
||||||
|
_add_sampling_metadata_broadcastable_dict,
|
||||||
|
_init_attn_metadata_from_tensor_dict,
|
||||||
|
_init_sampling_metadata_from_tensor_dict)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from vllm.attention.backends.abstract import AttentionBackend
|
||||||
|
|
||||||
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
TModelInputForCPU = TypeVar('TModelInputForCPU', bound="ModelInputForCPU")
|
||||||
|
_PAD_SLOT_ID = -1
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class ModelInputForCPU(ModelRunnerInputBase):
|
||||||
|
"""
|
||||||
|
Base class contains metadata needed for the base model forward pass on CPU
|
||||||
|
"""
|
||||||
|
input_tokens: Optional[torch.Tensor] = None
|
||||||
|
input_positions: Optional[torch.Tensor] = None
|
||||||
|
token_type_ids: Optional[torch.Tensor] = None
|
||||||
|
attn_metadata: Optional["AttentionMetadata"] = None
|
||||||
|
multi_modal_kwargs: Optional[BatchedTensorInputs] = None
|
||||||
|
virtual_engine: Optional[int] = None
|
||||||
|
seq_lens: Optional[List[int]] = None
|
||||||
|
query_lens: Optional[List[int]] = None
|
||||||
|
lora_mapping: Optional["LoRAMapping"] = None
|
||||||
|
lora_requests: Optional[Set[LoRARequest]] = None
|
||||||
|
|
||||||
|
def as_broadcastable_tensor_dict(
|
||||||
|
self) -> Dict[str, Union[int, torch.Tensor]]:
|
||||||
|
tensor_dict = {
|
||||||
|
"input_tokens": self.input_tokens,
|
||||||
|
"input_positions": self.input_positions,
|
||||||
|
"token_type_ids": self.token_type_ids,
|
||||||
|
"multi_modal_kwargs": self.multi_modal_kwargs,
|
||||||
|
"lora_requests": self.lora_requests,
|
||||||
|
"lora_mapping": self.lora_mapping,
|
||||||
|
}
|
||||||
|
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
|
||||||
|
|
||||||
|
return tensor_dict
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_broadcasted_tensor_dict(
|
||||||
|
cls: Type[TModelInputForCPU],
|
||||||
|
tensor_dict: Dict[str, Any],
|
||||||
|
attn_backend: Optional["AttentionBackend"] = None
|
||||||
|
) -> TModelInputForCPU:
|
||||||
|
if attn_backend is not None:
|
||||||
|
tensor_dict = _init_attn_metadata_from_tensor_dict(
|
||||||
|
attn_backend, tensor_dict)
|
||||||
|
return cls(**tensor_dict)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class ModelInputForCPUWithSamplingMetadata(ModelInputForCPU):
|
||||||
|
"""
|
||||||
|
Used by the ModelRunner.
|
||||||
|
"""
|
||||||
|
sampling_metadata: Optional["SamplingMetadata"] = None
|
||||||
|
is_prompt: Optional[bool] = None
|
||||||
|
|
||||||
|
def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
|
||||||
|
tensor_dict = {
|
||||||
|
"input_tokens": self.input_tokens,
|
||||||
|
"input_positions": self.input_positions,
|
||||||
|
"token_type_ids": self.token_type_ids,
|
||||||
|
"multi_modal_kwargs": self.multi_modal_kwargs,
|
||||||
|
}
|
||||||
|
_add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
|
||||||
|
_add_sampling_metadata_broadcastable_dict(tensor_dict,
|
||||||
|
self.sampling_metadata)
|
||||||
|
return tensor_dict
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_broadcasted_tensor_dict(
|
||||||
|
cls,
|
||||||
|
tensor_dict: Dict[str, Any],
|
||||||
|
attn_backend: Optional["AttentionBackend"] = None,
|
||||||
|
) -> "ModelInputForCPUWithSamplingMetadata":
|
||||||
|
tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict)
|
||||||
|
if attn_backend is not None:
|
||||||
|
tensor_dict = _init_attn_metadata_from_tensor_dict(
|
||||||
|
attn_backend, tensor_dict)
|
||||||
|
return cls(**tensor_dict)
|
||||||
|
|
||||||
|
|
||||||
|
class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]):
|
||||||
|
|
||||||
|
class ModelInputData:
|
||||||
|
|
||||||
|
def __init__(self, use_mrope: bool):
|
||||||
|
self.use_mrope = use_mrope
|
||||||
|
self.input_tokens: List[int] = []
|
||||||
|
self.input_positions: List[int] = []
|
||||||
|
self.token_type_ids: Optional[List[int]] = []
|
||||||
|
self.seq_lens: List[int] = []
|
||||||
|
self.query_lens: List[int] = []
|
||||||
|
self.prefill_block_tables: List[List[int]] = []
|
||||||
|
self.decode_block_tables: List[List[int]] = []
|
||||||
|
self.max_decode_seq_len: int = 0
|
||||||
|
self.num_prefills: int = 0
|
||||||
|
self.num_prefill_tokens: int = 0
|
||||||
|
self.num_decode_tokens: int = 0
|
||||||
|
self.slot_mapping: List[int] = []
|
||||||
|
self.multi_modal_inputs_list: List[MultiModalKwargs] = []
|
||||||
|
self.multi_modal_placeholder_maps: Dict[
|
||||||
|
str, MultiModalPlaceholderMap] = defaultdict(
|
||||||
|
MultiModalPlaceholderMap)
|
||||||
|
self.input_mrope_positions: List[List[int]] = [[]
|
||||||
|
for _ in range(3)]
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
runner: "CPUModelRunner",
|
||||||
|
finished_requests_ids: Optional[List[str]] = None) -> None:
|
||||||
|
super().__init__()
|
||||||
|
self.runner = runner
|
||||||
|
self.chunked_prefill = (runner.scheduler_config.chunked_prefill_enabled
|
||||||
|
or runner.cache_config.enable_prefix_caching)
|
||||||
|
self.model_input_cls = self.runner._model_input_cls
|
||||||
|
self.attn_backend = self.runner.attn_backend
|
||||||
|
self.sliding_window = self.runner.sliding_window
|
||||||
|
self.block_size = self.runner.block_size
|
||||||
|
self.device = self.runner.device
|
||||||
|
self.enable_lora = self.runner.lora_config is not None
|
||||||
|
if self.runner.attn_backend is not None:
|
||||||
|
# spec decode (e.g. Medusa) does not have atten backend
|
||||||
|
attn_backend = self.runner.attn_backend
|
||||||
|
self.att_metadata_builder = attn_backend.get_builder_cls()(self)
|
||||||
|
|
||||||
|
def prepare(self,
|
||||||
|
finished_requests_ids: Optional[List[str]] = None) -> None:
|
||||||
|
self.seq_group_metadata_list: List[SequenceGroupMetadata] = []
|
||||||
|
self.input_data = ModelInputForCPUBuilder.ModelInputData(
|
||||||
|
self.runner.model_config.uses_mrope)
|
||||||
|
self.att_metadata_builder.prepare()
|
||||||
|
|
||||||
|
def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
|
||||||
|
self.seq_group_metadata_list.append(seq_group_metadata)
|
||||||
|
|
||||||
|
def set_seq_group_list(
|
||||||
|
self, seq_group_metadata_list: List[SequenceGroupMetadata]):
|
||||||
|
self.seq_group_metadata_list = seq_group_metadata_list
|
||||||
|
|
||||||
|
def build(self) -> ModelInputForCPU:
|
||||||
|
self._build_input_data()
|
||||||
|
|
||||||
|
input_data = self.input_data
|
||||||
|
input_tokens = torch.tensor(input_data.input_tokens,
|
||||||
|
dtype=torch.long,
|
||||||
|
device="cpu")
|
||||||
|
input_positions = torch.tensor(
|
||||||
|
input_data.input_positions
|
||||||
|
if not any(input_data.input_mrope_positions) else
|
||||||
|
input_data.input_mrope_positions,
|
||||||
|
dtype=torch.long,
|
||||||
|
device="cpu")
|
||||||
|
token_type_ids = torch.tensor(input_data.token_type_ids,
|
||||||
|
dtype=torch.long,
|
||||||
|
device="cpu") \
|
||||||
|
if input_data.token_type_ids else None
|
||||||
|
|
||||||
|
# For multi-modal models
|
||||||
|
multi_modal_kwargs = None
|
||||||
|
if len(input_data.multi_modal_inputs_list) != 0:
|
||||||
|
multi_modal_kwargs = MultiModalKwargs.batch(
|
||||||
|
input_data.multi_modal_inputs_list)
|
||||||
|
|
||||||
|
attn_metadata = self.att_metadata_builder.build(
|
||||||
|
input_data.seq_lens, input_data.query_lens, -1, -1)
|
||||||
|
|
||||||
|
is_prompt = (self.seq_group_metadata_list[0].is_prompt
|
||||||
|
if self.seq_group_metadata_list else None)
|
||||||
|
# LoRA data.
|
||||||
|
lora_requests = set()
|
||||||
|
lora_mapping = None
|
||||||
|
if self.enable_lora:
|
||||||
|
lora_requests = set(seq.lora_request
|
||||||
|
for seq in self.seq_group_metadata_list
|
||||||
|
if seq.lora_request is not None)
|
||||||
|
|
||||||
|
lora_mapping = self._prepare_lora_input(
|
||||||
|
self.seq_group_metadata_list, is_prompt)
|
||||||
|
|
||||||
|
return self.model_input_cls(input_tokens=input_tokens,
|
||||||
|
input_positions=input_positions,
|
||||||
|
token_type_ids=token_type_ids,
|
||||||
|
seq_lens=input_data.seq_lens,
|
||||||
|
query_lens=input_data.query_lens,
|
||||||
|
attn_metadata=attn_metadata,
|
||||||
|
multi_modal_kwargs=multi_modal_kwargs,
|
||||||
|
lora_mapping=lora_mapping,
|
||||||
|
lora_requests=lora_requests)
|
||||||
|
|
||||||
|
def _build_input_data(self):
|
||||||
|
for seq_group_metadata in self.seq_group_metadata_list:
|
||||||
|
for seq_id, seq_data in seq_group_metadata.seq_data.items():
|
||||||
|
if seq_group_metadata.is_prompt:
|
||||||
|
self._compute_prompt_input_tokens(self.input_data,
|
||||||
|
seq_group_metadata,
|
||||||
|
seq_data, seq_id)
|
||||||
|
if seq_group_metadata.multi_modal_data:
|
||||||
|
self._compute_multi_modal_input(
|
||||||
|
seq_group_metadata, seq_data)
|
||||||
|
else:
|
||||||
|
self._compute_decode_input_tokens(self.input_data,
|
||||||
|
seq_group_metadata,
|
||||||
|
seq_data, seq_id)
|
||||||
|
|
||||||
|
def _compute_decode_input_tokens(self, data: ModelInputData,
|
||||||
|
seq_group_metadata: SequenceGroupMetadata,
|
||||||
|
seq_data: SequenceData, seq_id: int):
|
||||||
|
"""
|
||||||
|
Compute decode input tokens, positions, block table and slot mapping.
|
||||||
|
"""
|
||||||
|
block_size = self.runner.block_size
|
||||||
|
|
||||||
|
block_table = seq_group_metadata.block_tables[seq_id]
|
||||||
|
seq_len = seq_data.get_len()
|
||||||
|
context_len = seq_data.get_num_computed_tokens()
|
||||||
|
|
||||||
|
tokens = seq_data.get_last_token_id()
|
||||||
|
token_positions = seq_len - 1
|
||||||
|
block_number = block_table[token_positions // block_size]
|
||||||
|
block_offset = token_positions % block_size
|
||||||
|
slot = block_number * block_size + block_offset
|
||||||
|
|
||||||
|
# For paged_attention kernel
|
||||||
|
if self.runner.sliding_window:
|
||||||
|
start_idx = max(0, seq_len - self.runner.sliding_window)
|
||||||
|
start_block = start_idx // block_size
|
||||||
|
start_idx = start_block * block_size
|
||||||
|
seq_len = seq_len - start_idx
|
||||||
|
block_table = block_table[start_block:]
|
||||||
|
|
||||||
|
# For MRotaryEmbedding
|
||||||
|
if seq_data.mrope_position_delta is not None:
|
||||||
|
next_pos = MRotaryEmbedding.get_next_input_positions(
|
||||||
|
seq_data.mrope_position_delta,
|
||||||
|
context_len,
|
||||||
|
seq_len,
|
||||||
|
)
|
||||||
|
for idx in range(3):
|
||||||
|
data.input_mrope_positions[idx].extend( # type: ignore
|
||||||
|
next_pos[idx])
|
||||||
|
else:
|
||||||
|
data.input_positions.append(token_positions) # type: ignore
|
||||||
|
|
||||||
|
# Update fields
|
||||||
|
data.input_tokens.append(tokens)
|
||||||
|
data.max_decode_seq_len = max(data.max_decode_seq_len, seq_len)
|
||||||
|
data.num_decode_tokens += 1
|
||||||
|
data.slot_mapping.append(slot)
|
||||||
|
data.decode_block_tables.append(block_table)
|
||||||
|
data.query_lens.append(1)
|
||||||
|
data.seq_lens.append(seq_len)
|
||||||
|
|
||||||
|
def _compute_prompt_input_tokens(self, data: ModelInputData,
|
||||||
|
seq_group_metadata: SequenceGroupMetadata,
|
||||||
|
seq_data: SequenceData, seq_id: int):
|
||||||
|
"""
|
||||||
|
Compute prompt input tokens, positions, block table and slot mapping.
|
||||||
|
"""
|
||||||
|
token_chunk_size = seq_group_metadata.token_chunk_size
|
||||||
|
block_size = self.runner.block_size
|
||||||
|
|
||||||
|
block_table = seq_group_metadata.block_tables[seq_id]
|
||||||
|
seq_len = seq_data.get_len()
|
||||||
|
context_len = seq_data.get_num_computed_tokens()
|
||||||
|
seq_len = min(seq_len, context_len + token_chunk_size)
|
||||||
|
|
||||||
|
# For prefix caching
|
||||||
|
prefix_cache_block_num = len(seq_group_metadata.computed_block_nums)
|
||||||
|
if prefix_cache_block_num > 0:
|
||||||
|
prefix_cache_len = (prefix_cache_block_num *
|
||||||
|
self.runner.block_size)
|
||||||
|
if prefix_cache_len <= context_len:
|
||||||
|
# We already passed the cache hit region,
|
||||||
|
# so do normal computation.
|
||||||
|
pass
|
||||||
|
elif context_len < prefix_cache_len < seq_len:
|
||||||
|
# Partial hit. Compute the missing part.
|
||||||
|
context_len = prefix_cache_len
|
||||||
|
token_chunk_size = seq_len - context_len
|
||||||
|
elif seq_len <= prefix_cache_len:
|
||||||
|
# Full hit. Only compute the last token to avoid
|
||||||
|
# erroneous behavior. FIXME: Ideally we should directly
|
||||||
|
# mark all tokens as computed in the scheduler and do not
|
||||||
|
# schedule this sequence, so this case should not happen.
|
||||||
|
context_len = seq_len - 1
|
||||||
|
token_chunk_size = 1
|
||||||
|
|
||||||
|
tokens = seq_data.get_token_ids()
|
||||||
|
tokens = tokens[context_len:seq_len]
|
||||||
|
token_positions = range(context_len, seq_len)
|
||||||
|
token_types = seq_group_metadata.token_type_ids
|
||||||
|
|
||||||
|
# For encoder-only models, the block_table is None,
|
||||||
|
# and there is no need to initialize the slot_mapping.
|
||||||
|
if block_table is not None:
|
||||||
|
slot_mapping = [_PAD_SLOT_ID] * len(token_positions)
|
||||||
|
for i, pos in enumerate(token_positions):
|
||||||
|
block_number = block_table[pos // block_size]
|
||||||
|
block_offset = pos % block_size
|
||||||
|
slot = block_number * block_size + block_offset
|
||||||
|
slot_mapping[i] = slot
|
||||||
|
data.slot_mapping.extend(slot_mapping)
|
||||||
|
|
||||||
|
# The MROPE positions are prepared in _compute_multi_modal_input
|
||||||
|
data.input_positions.extend(token_positions)
|
||||||
|
|
||||||
|
if data.token_type_ids is not None:
|
||||||
|
data.token_type_ids.extend(token_types if token_types else [])
|
||||||
|
|
||||||
|
# Update fields
|
||||||
|
data.input_tokens.extend(tokens)
|
||||||
|
data.num_prefills += 1
|
||||||
|
data.num_prefill_tokens += len(tokens)
|
||||||
|
data.query_lens.append(len(tokens))
|
||||||
|
data.prefill_block_tables.append(block_table)
|
||||||
|
data.seq_lens.append(seq_len)
|
||||||
|
|
||||||
|
def _compute_multi_modal_input(self,
|
||||||
|
seq_group_metadata: SequenceGroupMetadata,
|
||||||
|
seq_data: SequenceData):
|
||||||
|
computed_len = seq_data.get_num_computed_tokens()
|
||||||
|
seq_len = self.input_data.seq_lens[-1]
|
||||||
|
|
||||||
|
# NOTE: mm_kwargs only includes the subset of multi-modal items that
|
||||||
|
# intersect with the current prefill positions.
|
||||||
|
mm_kwargs, placeholder_maps = MultiModalPlaceholderMap.from_seq_group(
|
||||||
|
seq_group_metadata, range(computed_len, seq_len))
|
||||||
|
|
||||||
|
if not mm_kwargs:
|
||||||
|
return
|
||||||
|
|
||||||
|
# special processing for mrope position deltas.
|
||||||
|
if self.runner.model_config.uses_mrope:
|
||||||
|
assert not self.chunked_prefill, \
|
||||||
|
"MROPE on CPU does not support chunked-prefill."
|
||||||
|
|
||||||
|
image_grid_thw = mm_kwargs.get("image_grid_thw", None)
|
||||||
|
video_grid_thw = mm_kwargs.get("video_grid_thw", None)
|
||||||
|
audio_feature_lengths = mm_kwargs.get("audio_feature_lengths",
|
||||||
|
None)
|
||||||
|
assert (
|
||||||
|
image_grid_thw is not None or video_grid_thw is not None
|
||||||
|
or audio_feature_lengths is not None), (
|
||||||
|
"mrope embedding type requires multi-modal input mapper "
|
||||||
|
"returns 'image_grid_thw' or 'video_grid_thw' or "
|
||||||
|
"'audio_feature_lengths'.")
|
||||||
|
|
||||||
|
second_per_grid_ts = mm_kwargs.get("second_per_grid_ts", None)
|
||||||
|
use_audio_in_video = mm_kwargs.get("use_audio_in_video", False)
|
||||||
|
hf_config = self.runner.model_config.hf_config
|
||||||
|
token_ids = seq_data.get_token_ids()
|
||||||
|
|
||||||
|
mrope_positions, mrope_position_delta = \
|
||||||
|
MRotaryEmbedding.get_input_positions(
|
||||||
|
token_ids,
|
||||||
|
hf_config=hf_config,
|
||||||
|
image_grid_thw=image_grid_thw,
|
||||||
|
video_grid_thw=video_grid_thw,
|
||||||
|
second_per_grid_ts=second_per_grid_ts,
|
||||||
|
context_len=computed_len,
|
||||||
|
audio_feature_lengths=audio_feature_lengths,
|
||||||
|
use_audio_in_video=use_audio_in_video,
|
||||||
|
)
|
||||||
|
seq_data.mrope_position_delta = mrope_position_delta
|
||||||
|
|
||||||
|
for i in range(3):
|
||||||
|
self.input_data.input_mrope_positions[ # type: ignore
|
||||||
|
i].extend(mrope_positions[i])
|
||||||
|
|
||||||
|
self.input_data.multi_modal_inputs_list.append(mm_kwargs)
|
||||||
|
for modality, placeholder_map in placeholder_maps.items():
|
||||||
|
self.input_data.multi_modal_placeholder_maps[modality].extend(
|
||||||
|
placeholder_map)
|
||||||
|
|
||||||
|
def _prepare_lora_input(
|
||||||
|
self, seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||||
|
is_prefill: bool) -> LoRAMapping:
|
||||||
|
index_mapping = []
|
||||||
|
prompt_mapping = []
|
||||||
|
for seq in seq_group_metadata_list:
|
||||||
|
lora_id = seq.lora_int_id
|
||||||
|
query_len = seq.token_chunk_size
|
||||||
|
|
||||||
|
index_mapping += [lora_id] * query_len
|
||||||
|
prompt_mapping += [lora_id] * (
|
||||||
|
query_len if seq.sampling_params
|
||||||
|
and seq.sampling_params.prompt_logprobs is not None else 1)
|
||||||
|
|
||||||
|
return LoRAMapping(index_mapping=tuple(index_mapping),
|
||||||
|
prompt_mapping=tuple(prompt_mapping),
|
||||||
|
is_prefill=is_prefill)
|
||||||
|
|
||||||
|
|
||||||
|
class CPUModelRunnerBase(ModelRunnerBase[TModelInputForCPU]):
|
||||||
|
"""
|
||||||
|
Helper class for shared methods between CPU model runners.
|
||||||
|
"""
|
||||||
|
_model_input_cls: Type[TModelInputForCPU]
|
||||||
|
_builder_cls: Type[ModelInputForCPUBuilder]
|
||||||
|
builder: ModelInputForCPUBuilder
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vllm_config: VllmConfig,
|
||||||
|
kv_cache_dtype: Optional[str] = "auto",
|
||||||
|
is_driver_worker: bool = False,
|
||||||
|
return_hidden_states: bool = False,
|
||||||
|
*args,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
ModelRunnerBase.__init__(self, vllm_config)
|
||||||
|
model_config = self.model_config
|
||||||
|
cache_config = self.cache_config
|
||||||
|
|
||||||
|
self.is_driver_worker = is_driver_worker
|
||||||
|
self.return_hidden_states = return_hidden_states
|
||||||
|
|
||||||
|
self.device = self.device_config.device
|
||||||
|
self.pin_memory = False
|
||||||
|
|
||||||
|
self.kv_cache_dtype = kv_cache_dtype
|
||||||
|
self.sliding_window = model_config.get_sliding_window()
|
||||||
|
self.block_size = cache_config.block_size
|
||||||
|
num_attn_heads = self.model_config.get_num_attention_heads(
|
||||||
|
self.parallel_config)
|
||||||
|
needs_attn_backend = (num_attn_heads != 0
|
||||||
|
or self.model_config.is_attention_free)
|
||||||
|
self.attn_backend = get_attn_backend(
|
||||||
|
self.model_config.get_head_size(),
|
||||||
|
self.model_config.dtype,
|
||||||
|
self.kv_cache_dtype,
|
||||||
|
self.block_size,
|
||||||
|
self.model_config.is_attention_free,
|
||||||
|
use_mla=self.model_config.use_mla,
|
||||||
|
) if needs_attn_backend else None
|
||||||
|
|
||||||
|
# Lazy initialization.
|
||||||
|
self.model: nn.Module # Set after init_Model
|
||||||
|
# Set after load_model.
|
||||||
|
self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None
|
||||||
|
self.sampler = get_sampler()
|
||||||
|
|
||||||
|
if hasattr(self, "_builder_cls"):
|
||||||
|
# multi-step model runner does not have `_builder_cls`
|
||||||
|
self.builder = self._builder_cls(weakref.proxy(self))
|
||||||
|
|
||||||
|
def load_model(self) -> None:
|
||||||
|
self.model = get_model(vllm_config=self.vllm_config)
|
||||||
|
|
||||||
|
if self.lora_config:
|
||||||
|
assert supports_lora(
|
||||||
|
self.model
|
||||||
|
), f"{self.model.__class__.__name__} does not support LoRA yet."
|
||||||
|
|
||||||
|
if supports_multimodal(self.model):
|
||||||
|
logger.warning("Regarding multimodal models, vLLM currently "
|
||||||
|
"only supports adding LoRA to language model.")
|
||||||
|
|
||||||
|
# Use get_text_config() in case of multimodal models
|
||||||
|
text_config = self.model_config.hf_config.get_text_config()
|
||||||
|
|
||||||
|
self.lora_manager = LRUCacheWorkerLoRAManager(
|
||||||
|
self.scheduler_config.max_num_seqs,
|
||||||
|
self.scheduler_config.max_num_batched_tokens,
|
||||||
|
self.vocab_size,
|
||||||
|
self.lora_config,
|
||||||
|
self.device,
|
||||||
|
self.model.embedding_modules,
|
||||||
|
self.model.embedding_padding_modules,
|
||||||
|
max_position_embeddings=text_config.max_position_embeddings,
|
||||||
|
)
|
||||||
|
self.model = self.lora_manager.create_lora_manager(self.model)
|
||||||
|
|
||||||
|
def get_model(self) -> nn.Module:
|
||||||
|
return self.model
|
||||||
|
|
||||||
|
def _prepare_model_input_tensors(
|
||||||
|
self,
|
||||||
|
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||||
|
finished_requests_ids: Optional[List[str]] = None
|
||||||
|
) -> TModelInputForCPU:
|
||||||
|
"""Helper method to prepare the model input based on a given sequence
|
||||||
|
group. Prepares metadata needed for the base model forward pass but not
|
||||||
|
metadata for possible additional steps, e.g., sampling.
|
||||||
|
|
||||||
|
"""
|
||||||
|
self.builder.prepare(finished_requests_ids)
|
||||||
|
self.builder.set_seq_group_list(seq_group_metadata_list)
|
||||||
|
|
||||||
|
return self.builder.build() # type: ignore
|
||||||
|
|
||||||
|
@property
|
||||||
|
def vocab_size(self) -> int:
|
||||||
|
return self.model_config.get_vocab_size()
|
||||||
|
|
||||||
|
def remove_all_loras(self):
|
||||||
|
if not self.lora_manager:
|
||||||
|
raise RuntimeError("LoRA is not enabled.")
|
||||||
|
self.lora_manager.remove_all_adapters()
|
||||||
|
|
||||||
|
def set_active_loras(self, lora_requests: Set[LoRARequest],
|
||||||
|
lora_mapping: LoRAMapping) -> None:
|
||||||
|
if not self.lora_manager:
|
||||||
|
raise RuntimeError("LoRA is not enabled.")
|
||||||
|
self.lora_manager.set_active_adapters(lora_requests, lora_mapping)
|
||||||
|
|
||||||
|
def add_lora(self, lora_request: LoRARequest) -> bool:
|
||||||
|
if not self.lora_manager:
|
||||||
|
raise RuntimeError("LoRA is not enabled.")
|
||||||
|
return self.lora_manager.add_adapter(lora_request)
|
||||||
|
|
||||||
|
def remove_lora(self, lora_id: int) -> bool:
|
||||||
|
if not self.lora_manager:
|
||||||
|
raise RuntimeError("LoRA is not enabled.")
|
||||||
|
return self.lora_manager.remove_adapter(lora_id)
|
||||||
|
|
||||||
|
def pin_lora(self, lora_id: int) -> bool:
|
||||||
|
if not self.lora_manager:
|
||||||
|
raise RuntimeError("LoRA is not enabled.")
|
||||||
|
return self.lora_manager.pin_adapter(lora_id)
|
||||||
|
|
||||||
|
def list_loras(self) -> Set[int]:
|
||||||
|
if not self.lora_manager:
|
||||||
|
raise RuntimeError("LoRA is not enabled.")
|
||||||
|
return self.lora_manager.list_adapters()
|
||||||
|
|
||||||
|
|
||||||
|
class CPUModelRunner(CPUModelRunnerBase[ModelInputForCPUWithSamplingMetadata]):
|
||||||
|
_model_input_cls: Type[ModelInputForCPUWithSamplingMetadata] = (
|
||||||
|
ModelInputForCPUWithSamplingMetadata)
|
||||||
|
_builder_cls: Type[ModelInputForCPUBuilder] = ModelInputForCPUBuilder
|
||||||
|
|
||||||
|
def make_model_input_from_broadcasted_tensor_dict(
|
||||||
|
self,
|
||||||
|
tensor_dict: Dict[str, Any],
|
||||||
|
) -> ModelInputForCPUWithSamplingMetadata:
|
||||||
|
return ModelInputForCPUWithSamplingMetadata.from_broadcasted_tensor_dict( # noqa: E501
|
||||||
|
tensor_dict,
|
||||||
|
attn_backend=self.attn_backend,
|
||||||
|
)
|
||||||
|
|
||||||
|
def prepare_model_input(
|
||||||
|
self,
|
||||||
|
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||||
|
virtual_engine: int = 0,
|
||||||
|
finished_requests_ids: Optional[List[str]] = None
|
||||||
|
) -> ModelInputForCPUWithSamplingMetadata:
|
||||||
|
"""Prepare the model input based on a given sequence group, including
|
||||||
|
metadata for the sampling step.
|
||||||
|
|
||||||
|
"""
|
||||||
|
model_input = self._prepare_model_input_tensors(
|
||||||
|
seq_group_metadata_list, finished_requests_ids)
|
||||||
|
# Sampling metadata is only required for the final pp group
|
||||||
|
generators = self.get_generators(finished_requests_ids)
|
||||||
|
sampling_metadata = SamplingMetadata.prepare(seq_group_metadata_list,
|
||||||
|
model_input.seq_lens,
|
||||||
|
model_input.query_lens,
|
||||||
|
self.device,
|
||||||
|
pin_memory=False,
|
||||||
|
generators=generators)
|
||||||
|
|
||||||
|
is_prompt = (seq_group_metadata_list[0].is_prompt
|
||||||
|
if seq_group_metadata_list else None)
|
||||||
|
return dataclasses.replace(model_input,
|
||||||
|
sampling_metadata=sampling_metadata,
|
||||||
|
virtual_engine=virtual_engine,
|
||||||
|
is_prompt=is_prompt)
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def execute_model(
|
||||||
|
self,
|
||||||
|
model_input: ModelInputForCPUWithSamplingMetadata,
|
||||||
|
kv_caches: List[torch.Tensor],
|
||||||
|
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||||
|
num_steps: int = 1,
|
||||||
|
previous_hidden_states: Optional[torch.Tensor] = None,
|
||||||
|
) -> Optional[List[SamplerOutput]]:
|
||||||
|
if num_steps > 1:
|
||||||
|
raise ValueError(
|
||||||
|
"CPU worker does not support multi-step execution.")
|
||||||
|
|
||||||
|
if self.lora_config:
|
||||||
|
assert model_input.lora_requests is not None
|
||||||
|
assert model_input.lora_mapping is not None
|
||||||
|
self.set_active_loras(model_input.lora_requests,
|
||||||
|
model_input.lora_mapping)
|
||||||
|
|
||||||
|
model_executable = self.model
|
||||||
|
|
||||||
|
multimodal_kwargs = {}
|
||||||
|
if model_input.multi_modal_kwargs is not None:
|
||||||
|
multimodal_kwargs = MultiModalKwargs.as_kwargs(
|
||||||
|
model_input.multi_modal_kwargs,
|
||||||
|
device=self.device,
|
||||||
|
)
|
||||||
|
execute_model_kwargs = {}
|
||||||
|
if previous_hidden_states is not None:
|
||||||
|
execute_model_kwargs.update(
|
||||||
|
{"previous_hidden_states": previous_hidden_states})
|
||||||
|
|
||||||
|
with set_forward_context(model_input.attn_metadata, self.vllm_config,
|
||||||
|
model_input.virtual_engine):
|
||||||
|
hidden_states = model_executable(
|
||||||
|
input_ids=model_input.input_tokens,
|
||||||
|
positions=model_input.input_positions,
|
||||||
|
intermediate_tensors=intermediate_tensors,
|
||||||
|
**execute_model_kwargs,
|
||||||
|
**multimodal_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Compute the logits.
|
||||||
|
logits = self.model.compute_logits(hidden_states,
|
||||||
|
model_input.sampling_metadata)
|
||||||
|
|
||||||
|
# Only perform sampling in the driver worker.
|
||||||
|
if not self.is_driver_worker:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Sample the next token.
|
||||||
|
output = self.sampler(
|
||||||
|
logits=logits,
|
||||||
|
sampling_metadata=model_input.sampling_metadata,
|
||||||
|
)
|
||||||
|
if self.return_hidden_states:
|
||||||
|
# we only need to pass hidden states of most recent token
|
||||||
|
if model_input.is_prompt:
|
||||||
|
output.prefill_hidden_states = hidden_states
|
||||||
|
output.hidden_states = hidden_states
|
||||||
|
return [output]
|
||||||
|
|
||||||
|
def generate_proposals(self, *args, **kwargs):
|
||||||
|
return self.model.generate_proposals(*args, **kwargs)
|
|
@ -0,0 +1,125 @@
|
||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import dataclasses
|
||||||
|
from typing import Any, Dict, List, Optional, Tuple, Type, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.forward_context import set_forward_context
|
||||||
|
from vllm.model_executor.pooling_metadata import PoolingMetadata
|
||||||
|
from vllm.multimodal import MultiModalKwargs
|
||||||
|
from vllm.pooling_params import PoolingParams
|
||||||
|
from vllm.sequence import (IntermediateTensors, PoolerOutput, SequenceData,
|
||||||
|
SequenceGroupMetadata)
|
||||||
|
from vllm.worker.cpu_model_runner import (CPUModelRunnerBase, ModelInputForCPU,
|
||||||
|
ModelInputForCPUBuilder)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass(frozen=True)
|
||||||
|
class ModelInputForCPUWithPoolingMetadata(ModelInputForCPU):
|
||||||
|
"""
|
||||||
|
Used by the CPUPoolingModelRunner.
|
||||||
|
"""
|
||||||
|
pooling_metadata: Optional["PoolingMetadata"] = None
|
||||||
|
|
||||||
|
|
||||||
|
class CPUPoolingModelRunner(
|
||||||
|
CPUModelRunnerBase[ModelInputForCPUWithPoolingMetadata]):
|
||||||
|
_model_input_cls: Type[ModelInputForCPUWithPoolingMetadata] = (
|
||||||
|
ModelInputForCPUWithPoolingMetadata)
|
||||||
|
_builder_cls: Type[ModelInputForCPUBuilder] = ModelInputForCPUBuilder
|
||||||
|
|
||||||
|
@torch.inference_mode()
|
||||||
|
def execute_model(
|
||||||
|
self,
|
||||||
|
model_input: ModelInputForCPUWithPoolingMetadata,
|
||||||
|
kv_caches: List[torch.Tensor],
|
||||||
|
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||||
|
num_steps: int = 1,
|
||||||
|
) -> Optional[Union[List[PoolerOutput], IntermediateTensors]]:
|
||||||
|
if num_steps > 1:
|
||||||
|
raise ValueError(
|
||||||
|
"CPU worker does not support multi-step execution.")
|
||||||
|
|
||||||
|
model_executable = self.model
|
||||||
|
cross_enc_kwargs = {}
|
||||||
|
if model_input.token_type_ids is not None:
|
||||||
|
cross_enc_kwargs["token_type_ids"] = model_input.token_type_ids
|
||||||
|
execute_model_kwargs = {
|
||||||
|
"input_ids":
|
||||||
|
model_input.input_tokens,
|
||||||
|
"positions":
|
||||||
|
model_input.input_positions,
|
||||||
|
**MultiModalKwargs.as_kwargs(
|
||||||
|
model_input.multi_modal_kwargs or {},
|
||||||
|
device=self.device,
|
||||||
|
),
|
||||||
|
**cross_enc_kwargs,
|
||||||
|
"intermediate_tensors":
|
||||||
|
intermediate_tensors,
|
||||||
|
}
|
||||||
|
|
||||||
|
with set_forward_context(model_input.attn_metadata, self.vllm_config,
|
||||||
|
model_input.virtual_engine):
|
||||||
|
hidden_states = model_executable(**execute_model_kwargs)
|
||||||
|
|
||||||
|
# Only perform pooling in the driver worker.
|
||||||
|
if not self.is_driver_worker:
|
||||||
|
return []
|
||||||
|
|
||||||
|
return [
|
||||||
|
self.model.pooler(hidden_states=hidden_states,
|
||||||
|
pooling_metadata=model_input.pooling_metadata)
|
||||||
|
]
|
||||||
|
|
||||||
|
def make_model_input_from_broadcasted_tensor_dict(
|
||||||
|
self,
|
||||||
|
tensor_dict: Dict[str,
|
||||||
|
Any]) -> ModelInputForCPUWithPoolingMetadata:
|
||||||
|
return ModelInputForCPUWithPoolingMetadata.from_broadcasted_tensor_dict(
|
||||||
|
tensor_dict,
|
||||||
|
attn_backend=self.attn_backend,
|
||||||
|
)
|
||||||
|
|
||||||
|
def prepare_model_input(
|
||||||
|
self,
|
||||||
|
seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
|
||||||
|
virtual_engine: int = 0,
|
||||||
|
finished_requests_ids: Optional[List[str]] = None
|
||||||
|
) -> ModelInputForCPUWithPoolingMetadata:
|
||||||
|
assert seq_group_metadata_list is not None
|
||||||
|
model_input = self._prepare_model_input_tensors(
|
||||||
|
seq_group_metadata_list, finished_requests_ids)
|
||||||
|
# Prepare PoolingMetadata.
|
||||||
|
assert model_input.seq_lens is not None
|
||||||
|
pooling_metadata = self._prepare_pooling(seq_group_metadata_list,
|
||||||
|
model_input.seq_lens)
|
||||||
|
|
||||||
|
return dataclasses.replace(model_input,
|
||||||
|
virtual_engine=virtual_engine,
|
||||||
|
pooling_metadata=pooling_metadata)
|
||||||
|
|
||||||
|
def _prepare_pooling(
|
||||||
|
self,
|
||||||
|
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||||
|
prompt_lens: List[int],
|
||||||
|
) -> PoolingMetadata:
|
||||||
|
"""Prepare PoolingMetadata for the sequence group metadata list."""
|
||||||
|
seq_groups: List[Tuple[List[int], PoolingParams]] = []
|
||||||
|
for i, seq_group_metadata in enumerate(seq_group_metadata_list):
|
||||||
|
seq_ids = list(seq_group_metadata.seq_data.keys())
|
||||||
|
pooling_params = seq_group_metadata.pooling_params
|
||||||
|
seq_groups.append((seq_ids, pooling_params))
|
||||||
|
|
||||||
|
seq_data: Dict[int, SequenceData] = {}
|
||||||
|
for seq_group_metadata in seq_group_metadata_list:
|
||||||
|
seq_data.update(seq_group_metadata.seq_data)
|
||||||
|
|
||||||
|
pooling_metadata = PoolingMetadata(
|
||||||
|
seq_groups=seq_groups,
|
||||||
|
seq_data=seq_data,
|
||||||
|
prompt_lens=prompt_lens,
|
||||||
|
)
|
||||||
|
|
||||||
|
return pooling_metadata
|
|
@@ -0,0 +1,452 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""A CPU worker class."""
import os
from importlib import util
from typing import List, Optional, Set, Tuple, Type

import torch
import torch.distributed

import vllm.envs as envs
from vllm.attention import get_attn_backend
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
                         ParallelConfig, VllmConfig)
from vllm.distributed import (ensure_model_parallel_initialized,
                              init_distributed_environment)
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor import set_random_seed
from vllm.sequence import ExecuteModelRequest
from vllm.utils import bind_kv_cache
from vllm.worker.cpu_enc_dec_model_runner import CPUEncoderDecoderModelRunner
from vllm.worker.cpu_model_runner import CPUModelRunner, CPUModelRunnerBase
from vllm.worker.cpu_pooling_model_runner import CPUPoolingModelRunner
from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase,
                                     WorkerInput)

logger = init_logger(__name__)


class CPUCacheEngine:
    """Manages the KV cache for CPU backend.

    This class is responsible for initializing and managing CPU KV
    caches. It also provides methods for performing KV cache operations, such
    as copying.
    """

    def __init__(self, cache_config: CacheConfig, model_config: ModelConfig,
                 parallel_config: ParallelConfig,
                 device_config: DeviceConfig) -> None:
        assert device_config.device_type == "cpu"
        self.cache_config = cache_config
        self.model_config = model_config
        self.parallel_config = parallel_config

        self.head_size = model_config.get_head_size()
        self.num_layers = model_config.get_num_layers(parallel_config)
        self.num_heads = model_config.get_num_kv_heads(parallel_config)

        self.block_size = cache_config.block_size
        # Note: In CacheConfig, num_gpu_blocks is actually num_cpu_blocks
        # for the CPU backend, because we want to reuse KV cache management
        # in the scheduler.
        self.num_cpu_blocks = cache_config.num_gpu_blocks

        self.dtype = CPUCacheEngine.get_kv_cache_dtype(cache_config,
                                                       model_config)

        # Get attention backend.
        self.attn_backend = get_attn_backend(
            self.model_config.get_head_size(),
            self.model_config.dtype,
            cache_config.cache_dtype,
            self.block_size,
            self.model_config.is_attention_free,
            use_mla=self.model_config.use_mla,
        )

        # Initialize the cache.
        self.cpu_cache = self._allocate_kv_cache(self.num_cpu_blocks)

    def _allocate_kv_cache(
        self,
        num_blocks: int,
    ) -> List[torch.Tensor]:
        """Allocates KV cache on CPU."""
        kv_cache_shape = self.attn_backend.get_kv_cache_shape(
            num_blocks, self.block_size, self.num_heads, self.head_size)
        kv_cache: List[torch.Tensor] = []
        for _ in range(self.num_layers):
            kv_cache.append(
                torch.empty(kv_cache_shape, dtype=self.dtype, device="cpu"))
        return kv_cache

    def swap_in(self, src_to_dst: torch.Tensor) -> None:
        raise NotImplementedError("Swap is not supported in CPUCacheEngine.")

    def swap_out(self, src_to_dst: torch.Tensor) -> None:
        raise NotImplementedError("Swap is not supported in CPUCacheEngine.")

    def copy(self, src_to_dsts: torch.Tensor) -> None:
        self.attn_backend.copy_blocks(self.cpu_cache, src_to_dsts)

    @staticmethod
    def get_kv_cache_dtype(cache_config: CacheConfig,
                           model_config: ModelConfig):
        if cache_config.cache_dtype == "auto":
            return model_config.dtype
        elif cache_config.cache_dtype in ["fp8", "fp8_e5m2"]:
            return torch.float8_e5m2
        else:
            raise NotImplementedError(f"Unsupported KV cache type "
                                      f"{cache_config.cache_dtype}.")

    @staticmethod
    def get_cache_block_size(
        cache_config: CacheConfig,
        model_config: ModelConfig,
        parallel_config: ParallelConfig,
    ) -> int:
        head_size = model_config.get_head_size()
        num_heads = model_config.get_num_kv_heads(parallel_config)
        num_layers = model_config.get_num_layers(parallel_config)

        key_cache_block = cache_config.block_size * num_heads * head_size
        value_cache_block = key_cache_block if not model_config.use_mla else 0
        total = num_layers * (key_cache_block + value_cache_block)
        dtype = CPUCacheEngine.get_kv_cache_dtype(cache_config, model_config)
        dtype_size = torch.tensor([], dtype=dtype).element_size()
        return dtype_size * total

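# Illustrative sizing note (not part of the original file): for a hypothetical
# model with 32 layers, 8 KV heads, head_size 128, block_size 16 and a float16
# cache, `CPUCacheEngine.get_cache_block_size` would return
# 32 * (16 * 8 * 128 + 16 * 8 * 128) * 2 bytes = 2 MiB per KV cache block.
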
class CPUWorker(LocalOrDistributedWorkerBase):
    """A worker class that executes (a partition of) the model on a CPU socket.

    Each worker is associated with a single CPU socket. The worker is
    responsible for maintaining the KV cache and executing the model on the
    CPU. In case of distributed inference, each worker is assigned a partition
    of the model.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        kv_cache_dtype: Optional[str] = "auto",
        is_driver_worker: bool = False,
        model_runner_cls: Optional[Type[CPUModelRunner]] = None,
    ) -> None:
        WorkerBase.__init__(self, vllm_config=vllm_config)

        self.local_rank = local_rank
        self.rank = rank
        vllm_config.parallel_config.rank = rank

        self.distributed_init_method = distributed_init_method

        self.is_driver_worker = is_driver_worker
        if self.is_driver_worker:
            assert self.rank == 0, "The driver worker must have rank 0."

        if self.model_config.trust_remote_code:
            # note: lazy import to avoid importing torch before initializing
            from vllm.utils import init_cached_hf_modules
            init_cached_hf_modules()

        # Setup OpenMP threads affinity.
        omp_cpuids = envs.VLLM_CPU_OMP_THREADS_BIND
        self.local_omp_cpuid = "all"
        if omp_cpuids == "auto":
            self.local_omp_cpuid = self.get_cpus_id_binding_based_on_numa_nodes(
            )
        else:
            self.local_omp_cpuid = omp_cpuids.split("|")[rank]

        # Return hidden states from target model if the draft model is an
        # mlp_speculator
        speculative_config = self.speculative_config
        model_config = self.model_config
        speculative_args = {} if speculative_config is None \
            or (speculative_config.draft_model_config.model ==
                model_config.model) \
            or (speculative_config.draft_model_config.hf_config.model_type
                not in ["medusa", "mlp_speculator", "eagle"]) \
            else {"return_hidden_states": True}
        ModelRunnerClass: Type[CPUModelRunnerBase] = CPUModelRunner
        if self.model_config.runner_type == "pooling":
            ModelRunnerClass = CPUPoolingModelRunner
        elif self.model_config.is_encoder_decoder:
            ModelRunnerClass = CPUEncoderDecoderModelRunner
        self.model_runner: CPUModelRunnerBase = ModelRunnerClass(
            vllm_config=vllm_config,
            kv_cache_dtype=kv_cache_dtype,
            is_driver_worker=is_driver_worker,
            **speculative_args,
        )
        if model_runner_cls is not None:
            self.model_runner = model_runner_cls(self.model_runner)
        # Uninitialized cache engine. Will be initialized by
        # initialize_cache.
        self.cache_engine: List[CPUCacheEngine]
        # Initialize cpu_cache as pooling models don't initialize kv_caches
        self.cpu_cache: Optional[List[List[torch.Tensor]]] = None

        # Torch profiler. Enabled and configured through env vars:
        # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace
        if envs.VLLM_TORCH_PROFILER_DIR:
            torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR
            logger.info("Profiling enabled. Traces will be saved to: %s",
                        torch_profiler_trace_dir)
            self.profiler = torch.profiler.profile(
                activities=[
                    torch.profiler.ProfilerActivity.CPU,
                ],
                with_stack=True,
                on_trace_ready=torch.profiler.tensorboard_trace_handler(
                    torch_profiler_trace_dir, use_gzip=True))
        else:
            self.profiler = None

    def start_profile(self):
        if self.profiler is None:
            raise RuntimeError("Profiler is not enabled.")
        self.profiler.start()

    def stop_profile(self):
        if self.profiler is None:
            raise RuntimeError("Profiler is not enabled.")
        self.profiler.stop()

    def init_device(self) -> None:
        if self.local_omp_cpuid != "all":
            ret = torch.ops._C_utils.init_cpu_threads_env(self.local_omp_cpuid)
            if ret:
                logger.info(ret)

        # Note: unique identifier for creating allreduce shared memory
        os.environ["VLLM_DIST_IDENT"] = self.distributed_init_method.split(
            ":")[-1]
        self.device = torch.device("cpu")
        self.init_distributed_environment()
        # Set random seed.
        set_random_seed(self.model_config.seed)

    def load_model(self):
        self.model_runner.load_model()

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Determine the number of blocks available for the KV cache.

        This determines how many KV blocks can fit into the configured CPU
        KV cache space.

        Note that since vLLM assumes a block resides on GPU if it can be
        modified, we return num_gpu_blocks=num_cpu_blocks and num_cpu_blocks=0.
        This allows us to reuse the scheduler of vLLM without generalizing it
        to different devices.
        """
        # For CPU device, the block number will be calculated based on the
        # cpu_kvcache_space.
        cache_block_size = self.get_cache_block_size_bytes()
        num_cpu_blocks = int(self.cache_config.cpu_kvcache_space_bytes //
                             cache_block_size)
        num_cpu_blocks = max(num_cpu_blocks, 0)

        # Note: To reuse the cache management procedure,
        # use cpu cache as 'gpu cache'.
        num_gpu_blocks = num_cpu_blocks
        num_cpu_blocks = 0
        return num_gpu_blocks, num_cpu_blocks

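    # Illustrative example (not part of the original file): with a
    # hypothetical 4 GiB `VLLM_CPU_KVCACHE_SPACE` and a 2 MiB cache block,
    # `determine_num_available_blocks` would report roughly
    # 4 * 1024 // 2 = 2048 blocks, returned as `num_gpu_blocks` so the
    # existing scheduler bookkeeping can be reused unchanged.
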
    def initialize_cache(self, num_gpu_blocks: int,
                         num_cpu_blocks: int) -> None:
        """Initialize the KV cache. Currently, swappable CPU memory is not
        supported.

        Since this worker does not support GPUs, we use the num_gpu_blocks to
        determine how many non-swappable CPU blocks to allocate.
        """
        assert (num_cpu_blocks == 0
                ), f"{type(self)} does not support swappable cache"

        # Note: To reuse the cache management procedure,
        # use cpu cache as 'gpu cache'.
        num_cpu_blocks = num_gpu_blocks

        self._validate_num_cpu_blocks(num_cpu_blocks)
        self.cache_config.num_gpu_blocks = num_cpu_blocks
        self.cache_config.num_cpu_blocks = 0

        # Initialize the cache.
        self._init_cache_engine()

    def add_lora(self, lora_request: LoRARequest) -> bool:
        return self.model_runner.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        return self.model_runner.remove_lora(lora_id)

    def pin_lora(self, lora_id: int) -> bool:
        return self.model_runner.pin_lora(lora_id)

    def list_loras(self) -> Set[int]:
        return self.model_runner.list_loras()

    def _validate_num_cpu_blocks(self, num_cpu_blocks: int) -> None:
        """Raise errors if the num_cpu_blocks is invalid.
        """
        if num_cpu_blocks <= 0:
            raise ValueError("No available memory for the cache blocks. "
                             "Try increasing `VLLM_CPU_KVCACHE_SPACE` when "
                             "initializing the engine.")

        max_seq_len = self.cache_config.block_size * num_cpu_blocks
        if self.model_config.max_model_len > max_seq_len:
            raise ValueError(
                f"The model's max seq len ({self.model_config.max_model_len}) "
                "is larger than the maximum number of tokens that can be "
                f"stored in KV cache ({max_seq_len}). Try increasing "
                "`VLLM_CPU_KVCACHE_SPACE` or decreasing `max_model_len` when "
                "initializing the engine.")

    def _init_cache_engine(self) -> None:
        self.cache_engine = [
            CPUCacheEngine(self.cache_config, self.model_config,
                           self.parallel_config, self.device_config)
            for _ in range(self.parallel_config.pipeline_parallel_size)
        ]
        self.cpu_cache = [
            self.cache_engine[ve].cpu_cache
            for ve in range(self.parallel_config.pipeline_parallel_size)
        ]
        bind_kv_cache(self.compilation_config.static_forward_context,
                      self.cpu_cache)
        self.model_runner.block_size = self.cache_engine[0].block_size

        assert all(
            self.cpu_cache[ve] is not None
            for ve in range(self.parallel_config.pipeline_parallel_size))

        # Populate the cache to warmup the memory
        for ve in range(self.parallel_config.pipeline_parallel_size):
            for layer_cache in self.cpu_cache[ve]:
                layer_cache.fill_(0)

    @property
    def do_metadata_broadcast(self) -> bool:
        return self.parallel_config.tensor_parallel_size > 1

    @property
    def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
        return self.cpu_cache

    @property
    def vocab_size(self) -> int:
        return self.model_runner.vocab_size

    @property
    def max_model_len(self) -> int:
        return self.model_config.max_model_len

    def execute_worker(
        self,
        worker_input: WorkerInput,
    ) -> None:
        if (worker_input.blocks_to_copy is not None
                and worker_input.blocks_to_copy.numel() > 0):
            self.cache_engine[worker_input.virtual_engine].copy(
                worker_input.blocks_to_copy)

    @torch.inference_mode()
    def prepare_worker_input(
            self, execute_model_req: ExecuteModelRequest) -> WorkerInput:
        assert execute_model_req is not None
        virtual_engine: int = execute_model_req.virtual_engine
        num_seq_groups: int = len(execute_model_req.seq_group_metadata_list)
        blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy,
                                      device="cpu",
                                      dtype=torch.int64).view(-1, 2)
        assert len(execute_model_req.blocks_to_swap_in) == 0
        assert len(execute_model_req.blocks_to_swap_out) == 0
        return WorkerInput(
            num_seq_groups=num_seq_groups,
            blocks_to_copy=blocks_to_copy,
            virtual_engine=virtual_engine,
        )

    def init_distributed_environment(self) -> None:
        """Initialize the distributed environment."""

        parallel_config = self.parallel_config
        rank = self.rank
        distributed_init_method = self.distributed_init_method
        init_distributed_environment(
            world_size=parallel_config.world_size,
            rank=rank,
            distributed_init_method=distributed_init_method,
            backend="gloo",
        )

        # A small all_reduce for warmup.
        torch.distributed.all_reduce(torch.zeros(1).cpu())

        ensure_model_parallel_initialized(
            parallel_config.tensor_parallel_size,
            parallel_config.pipeline_parallel_size)

    def get_cache_block_size_bytes(self) -> int:
        """Return the size in bytes of a single KV cache block.
        """
        return CPUCacheEngine.get_cache_block_size(self.cache_config,
                                                   self.model_config,
                                                   self.parallel_config)

    def get_cpus_id_binding_based_on_numa_nodes(self) -> str:
        """Return CPUs id binding based on NUMA nodes.
        """
        rank_to_cpus = self.local_omp_cpuid
        # Setup OpenMP thread affinity based on NUMA nodes automatically
        world_size = self.vllm_config.parallel_config.world_size
        libnuma_found = util.find_spec("numa") is not None
        psutil_found = util.find_spec("psutil") is not None
        if libnuma_found and psutil_found:
            import psutil
            from numa import info
            cpu_count = psutil.cpu_count(logical=False)
            cpus_allow_list = psutil.Process().cpu_affinity()
            numa_size = info.get_num_configured_nodes()
            cpu_count_per_numa = cpu_count // numa_size
            num_of_reserved_cpu = min(envs.VLLM_CPU_NUM_OF_RESERVED_CPU,
                                      cpu_count_per_numa // 2)

            # check allowed node_to_cpus list
            node_to_cpus = []
            for i in range(numa_size):
                node_intersect = set(
                    info.node_to_cpus(i)).intersection(cpus_allow_list)
                if bool(node_intersect):
                    node_to_cpus.append(list(node_intersect))

            if world_size > len(node_to_cpus):
                logger.error(
                    "Auto thread-binding failed because "
                    "world size: %d is larger than "
                    "the allowed NUMA node count: %d. "
                    "Please try to bind threads manually.", world_size,
                    len(node_to_cpus))
            else:
                end = cpu_count_per_numa - num_of_reserved_cpu
                rank_to_cpus_list = node_to_cpus[self.rank][:end]
                rank_to_cpus = ','.join(str(x) for x in rank_to_cpus_list)
                logger.info("auto thread-binding list: %s", rank_to_cpus)
        else:
            logger.warning(
                "Auto thread-binding is not supported because the numa and "
                "psutil packages are missing; falling back to no "
                "thread-binding. To get better performance, please try to "
                "manually bind threads.")
        return rank_to_cpus
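# Note (illustrative, not part of the original file): when auto-binding is not
# used, `VLLM_CPU_OMP_THREADS_BIND` takes one CPU-id spec per rank separated by
# "|"; e.g. a hypothetical "0-31|32-63" pins rank 0 to cores 0-31 and rank 1 to
# cores 32-63, matching the `omp_cpuids.split("|")[rank]` lookup above.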
@@ -0,0 +1,108 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import dataclasses
from typing import Dict, Optional, Tuple

import torch

from vllm.distributed import broadcast_tensor_dict
from vllm.sequence import ExecuteModelRequest
from vllm.worker.tpu_model_runner import ModelInputForTPU
from vllm.worker.tpu_worker import TPUWorker
from vllm.worker.worker_base import WorkerInput


class MultiStepTPUWorker(TPUWorker):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.cached_model_input: Optional[ModelInputForTPU] = None

    def _get_driver_input_and_broadcast(
        self, execute_model_req: ExecuteModelRequest
    ) -> Tuple[ModelInputForTPU, WorkerInput, Dict[str, torch.Tensor]]:
        assert self.is_driver_worker
        assert execute_model_req.virtual_engine == 0

        is_first_multi_step = execute_model_req.is_first_multi_step
        is_last_step = execute_model_req.is_last_step
        if is_first_multi_step:
            worker_input: WorkerInput = self.prepare_worker_input(
                execute_model_req=execute_model_req)
            worker_input = dataclasses.replace(
                worker_input,
                num_steps=execute_model_req.num_lookahead_slots + 1)
            model_input: ModelInputForTPU = (
                self.model_runner.prepare_model_input(
                    execute_model_req.seq_group_metadata_list,
                    execute_model_req.virtual_engine,
                    execute_model_req.finished_requests_ids))

            if execute_model_req.async_callback:
                model_input = dataclasses.replace(
                    model_input,
                    async_callback=execute_model_req.async_callback)
        else:
            assert self.cached_model_input is not None
            model_input = self.cached_model_input
            worker_input = WorkerInput()
        model_input = dataclasses.replace(
            model_input,
            is_first_multi_step=is_first_multi_step,
            is_last_step=is_last_step)

        if self.do_metadata_broadcast:
            if is_first_multi_step:
                broadcast_data = worker_input.as_broadcastable_tensor_dict()
                broadcast_data.update(
                    model_input.as_broadcastable_tensor_dict())
                broadcast_tensor_dict(broadcast_data, src=0)
            else:
                broadcast_data = {
                    "is_first_multi_step": is_first_multi_step,
                    "is_last_step": is_last_step,
                }
                broadcast_tensor_dict(broadcast_data, src=0)

        # Returning an empty dict here to keep this compatible with
        # `LocalOrDistributedWorkerBase._get_driver_input_and_broadcast`
        return model_input, worker_input, {}

    def prepare_input(
        self,
        execute_model_req: Optional[ExecuteModelRequest] = None,
    ) -> Optional[Tuple[ModelInputForTPU, WorkerInput, Dict[str,
                                                            torch.Tensor]]]:
        if self.is_driver_worker:
            if execute_model_req is None:
                if self.do_metadata_broadcast:
                    broadcast_tensor_dict({}, src=0)
                return None

            model_input, worker_input, _ = self._get_driver_input_and_broadcast(
                execute_model_req)
            if model_input.is_first_multi_step:
                self.cached_model_input = model_input
            return model_input, worker_input, {}
        else:
            broadcast_data = broadcast_tensor_dict(src=0)
            if not broadcast_data:
                return None

            if len(broadcast_data) == 2:
                assert self.cached_model_input is not None
                self.cached_model_input = dataclasses.replace(
                    self.cached_model_input,
                    is_first_multi_step=broadcast_data["is_first_multi_step"],
                    is_last_step=broadcast_data["is_last_step"])
                empty_worker_input = WorkerInput()
                return self.cached_model_input, empty_worker_input, {}

            worker_input = WorkerInput.from_broadcasted_tensor_dict(
                broadcast_data)
            model_input = (
                self.model_runner.
                make_model_input_from_broadcasted_tensor_dict(broadcast_data))
            self.cached_model_input = model_input
            return model_input, worker_input, {}
@@ -0,0 +1,909 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import enum
import time
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple,
                    Type, Union)
from unittest.mock import patch

import numpy as np
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr

from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.config import VllmConfig
from vllm.forward_context import get_forward_context, set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import (CompletionSequenceGroupOutput, IntermediateTensors,
                           Logprob, SequenceGroupMetadata, SequenceOutput)
from vllm.worker.model_runner_base import (
    ModelRunnerBase, ModelRunnerInputBase,
    _add_attn_metadata_broadcastable_dict,
    _init_attn_metadata_from_tensor_dict)

if TYPE_CHECKING:
    from vllm.attention.backends.abstract import AttentionBackend

logger = init_logger(__name__)

# Here we utilize the behavior that out-of-bound index is ignored.
# FIXME(woosuk): Find a more reliable way to prevent possible bugs.
_PAD_SLOT_ID = 1_000_000_000
# FIXME(woosuk): Temporarily disabled top-p sampling since it's too slow.
_ENABLE_TOP_P = False
# FIXME(woosuk): A temporary hack to support `n > 1`.
# This can significantly affect the performance if too large.
_MAX_NUM_SAMPLES = 128


class ExecutionMode(enum.Enum):
    PREFILL = enum.auto()
    DECODE = enum.auto()
    PREFIX_PREFILL = enum.auto()

    def is_prefill(self) -> bool:
        return self in (ExecutionMode.PREFILL, ExecutionMode.PREFIX_PREFILL)


@dataclass(frozen=True)
class ModelInputForTPU(ModelRunnerInputBase):
    token_ids: torch.Tensor
    position_ids: torch.Tensor
    attn_metadata: AttentionMetadata
    input_lens: torch.Tensor
    t: torch.Tensor
    p: torch.Tensor
    num_samples: int
    n: List[int]
    seq_groups: List[List[int]]
    is_first_multi_step: bool = True
    is_last_step: bool = True
    virtual_engine: int = 0
    async_callback: Optional[Callable] = None

    def as_broadcastable_tensor_dict(
            self) -> Dict[str, Union[int, torch.Tensor]]:
        tensor_dict = {
            "token_ids": self.token_ids,
            "position_ids": self.position_ids,
            "input_lens": self.input_lens,
            "t": self.t,
            "p": self.p,
            "num_samples": self.num_samples,
            "n": self.n,
            "seq_groups": self.seq_groups,
            "is_first_multi_step": self.is_first_multi_step,
            "is_last_step": self.is_last_step,
            "virtual_engine": self.virtual_engine,
        }
        _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
        return tensor_dict

    @classmethod
    def from_broadcasted_tensor_dict(
        cls: Type["ModelInputForTPU"],
        tensor_dict: Dict[str, Any],
        attn_backend: Optional["AttentionBackend"] = None,
    ) -> "ModelInputForTPU":
        if attn_backend is not None:
            tensor_dict = _init_attn_metadata_from_tensor_dict(
                attn_backend, tensor_dict)
        return cls(**tensor_dict)

class TPUModelRunner(ModelRunnerBase[ModelInputForTPU]):

    def __init__(
        self,
        vllm_config: VllmConfig,
        is_driver_worker: bool = False,
    ):
        ModelRunnerBase.__init__(self, vllm_config=vllm_config)
        self.is_driver_worker = is_driver_worker

        self.block_size = self.cache_config.block_size
        self.max_num_blocks_per_seq = (self.model_config.max_model_len //
                                       self.block_size)
        self.block_tables = np.zeros(
            (self.scheduler_config.max_num_seqs, self.max_num_blocks_per_seq),
            dtype=np.int32)
        self.attn_backend = get_attn_backend(
            self.model_config.get_head_size(),
            self.model_config.dtype,
            self.cache_config.cache_dtype,
            self.block_size,
            self.model_config.is_attention_free,
            False,
        )
        self.cached_step_outputs: List[torch.Tensor] = []

        smem_size = 512 * 1024
        block_table_size = 4 * self.block_tables.size
        if block_table_size >= smem_size:
            logger.warning(
                "The max_model_len (%d) is too large. This may degrade the "
                "performance due to the insufficient smem size. Consider "
                "setting --max-model-len to a smaller value, like %d.",
                self.model_config.max_model_len,
                self.model_config.max_model_len /
                (block_table_size / smem_size))

    def load_model(self) -> None:
        self.device = self.device_config.device

        # NOTE(woosuk): While the executor assigns the TP ranks to the worker
        # process, the ranks can be different from the ranks internally
        # assigned by the xm runtime. Therefore, there is a mismatch in the
        # rank assignment between the gloo (cpu) runtime and the xm (tpu)
        # runtime. This is not a problem in linear layers because all-reduce
        # is rank-agnostic. However, it matters for all-gather as the ranks
        # determine the order of concatenating the output tensors.
        # As a workaround, we use the xm's rank assignment only when loading
        # the embedding weights.
        xm_tp_rank = xr.global_ordinal()
        with patch(
                "vllm.model_executor.layers.vocab_parallel_embedding."
                "get_tensor_model_parallel_rank",
                return_value=xm_tp_rank):
            model = get_model(vllm_config=self.vllm_config)
        model = model.eval()
        xm.wait_device_ops()
        model = ModelWrapper(model)
        self.model = torch.compile(model,
                                   backend="openxla",
                                   fullgraph=True,
                                   dynamic=False)

    def get_model(self) -> nn.Module:
        return self.model.model

    def _dummy_run(
        self,
        batch_size: int,
        seq_len: int,
        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
        exec_mode: ExecutionMode,
    ) -> None:
        exec_mode = ExecutionMode(exec_mode)
        if exec_mode.is_prefill():
            seq_len = (seq_len + 15) // 16 * 16
            token_ids = torch.zeros((batch_size, seq_len),
                                    dtype=torch.int32,
                                    device=self.device)
            position_ids = torch.zeros((batch_size, seq_len),
                                       dtype=torch.int32,
                                       device=self.device)
            slot_mapping = torch.zeros((batch_size, seq_len),
                                       dtype=torch.int64,
                                       device=self.device)
            input_lens = torch.ones((batch_size, ),
                                    dtype=torch.int32,
                                    device=self.device)
            if exec_mode == ExecutionMode.PREFILL:
                attn_metadata = self.attn_backend.make_metadata(
                    num_prefills=batch_size,
                    num_prefill_tokens=batch_size * seq_len,
                    num_decode_tokens=0,
                    slot_mapping=slot_mapping,
                    multi_modal_placeholder_index_maps=None,
                    enable_kv_scales_calculation=False,
                    block_tables=None,
                    context_lens=None,
                    effective_query_lens=None,
                )
            else:
                context_lens = torch.ones((batch_size, ),
                                          dtype=torch.int32,
                                          device=self.device)
                block_tables = torch.tensor(self.block_tables[:batch_size],
                                            dtype=torch.int32,
                                            device=self.device)
                effective_query_lens = torch.ones_like(context_lens)
                attn_metadata = self.attn_backend.make_metadata(
                    num_prefills=batch_size,
                    num_prefill_tokens=batch_size * seq_len,
                    num_decode_tokens=0,
                    slot_mapping=slot_mapping,
                    multi_modal_placeholder_index_maps=None,
                    enable_kv_scales_calculation=False,
                    block_tables=block_tables,
                    context_lens=context_lens,
                    effective_query_lens=effective_query_lens,
                )
        else:
            assert seq_len == 1
            token_ids = torch.zeros((batch_size, seq_len),
                                    dtype=torch.int32,
                                    device=self.device)
            position_ids = torch.zeros((batch_size, seq_len),
                                       dtype=torch.int32,
                                       device=self.device)
            slot_mapping = torch.zeros((batch_size, seq_len),
                                       dtype=torch.int64,
                                       device=self.device)
            block_tables = torch.zeros(
                (batch_size, self.max_num_blocks_per_seq),
                dtype=torch.int32,
                device=self.device)
            context_lens = torch.ones((batch_size, ),
                                      dtype=torch.int32,
                                      device=self.device)
            input_lens = torch.ones((batch_size, ),
                                    dtype=torch.int32,
                                    device=self.device)
            attn_metadata = self.attn_backend.make_metadata(
                num_prefills=0,
                num_prefill_tokens=0,
                num_decode_tokens=batch_size * seq_len,
                slot_mapping=slot_mapping,
                multi_modal_placeholder_index_maps=None,
                enable_kv_scales_calculation=False,
                block_tables=block_tables,
                context_lens=context_lens,
            )
        t = torch.ones((batch_size, ), dtype=torch.float32, device=self.device)
        p = torch.ones((batch_size, ), dtype=torch.float32, device=self.device)
        num_samples = _MAX_NUM_SAMPLES if exec_mode.is_prefill() else 1

        # NOTE(woosuk): There are two stages of compilation: torch.compile and
        # XLA compilation. Using `mark_dynamic` can reduce the torch.compile
        # overhead by reusing the FX graph for different shapes.
        # However, the XLA graph will still require static shapes and needs to
        # be re-compiled for every different shape. This overhead is
        # inevitable in the first run, but can be skipped afterwards as we
        # cache the XLA graphs on disk (VLLM_XLA_CACHE_PATH).
        if exec_mode.is_prefill():
            # Prefill
            torch._dynamo.mark_dynamic(token_ids, 1)
            torch._dynamo.mark_dynamic(position_ids, 1)
            torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 1)
        else:
            # Decode
            torch._dynamo.mark_dynamic(token_ids, 0)
            torch._dynamo.mark_dynamic(position_ids, 0)
            torch._dynamo.mark_dynamic(input_lens, 0)
            torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0)
            torch._dynamo.mark_dynamic(attn_metadata.context_lens, 0)
            torch._dynamo.mark_dynamic(attn_metadata.block_tables, 0)
            torch._dynamo.mark_dynamic(t, 0)
            torch._dynamo.mark_dynamic(p, 0)
        # Dummy run.
        with set_forward_context(attn_metadata, self.vllm_config, 0):
            self.model(token_ids, position_ids, input_lens, t, p, num_samples,
                       kv_caches)

    def warmup_model(
        self,
        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
    ) -> None:
        # Prefill
        logger.info("Compiling the model with different input shapes...")
        start = time.time()
        for batch_size in [1]:
            seq_len = 16
            while seq_len <= self.model_config.max_model_len:
                self._dummy_run(batch_size,
                                seq_len,
                                kv_caches,
                                exec_mode=ExecutionMode.PREFILL)
                xm.wait_device_ops()
                logger.info("batch_size: %d, seq_len: %d", batch_size, seq_len)
                num_tokens = batch_size * seq_len
                if num_tokens >= self.scheduler_config.max_num_batched_tokens:
                    break
                seq_len = seq_len * 2

        end = time.time()
        logger.info("Compilation for prefill done in %.2f s.", end - start)

        # Prefix prefill
        if self.cache_config.enable_prefix_caching:
            logger.info("Compiling the model with different input shapes for "
                        "prefix prefill...")
            start = time.time()
            for batch_size in [1]:
                seq_len = 16
                while seq_len <= self.model_config.max_model_len:
                    self._dummy_run(batch_size,
                                    seq_len,
                                    kv_caches,
                                    exec_mode=ExecutionMode.PREFIX_PREFILL)
                    xm.wait_device_ops()
                    logger.info("batch_size: %d, seq_len: %d", batch_size,
                                seq_len)
                    num_tokens = batch_size * seq_len
                    if (num_tokens
                            >= self.scheduler_config.max_num_batched_tokens):
                        break
                    seq_len = seq_len * 2
            end = time.time()
            logger.info("Compilation for prefix prefill done in %.2f s.",
                        end - start)

        # Decode
        start = time.time()
        seq_len = 1
        batch_size = 8  # Must be in sync with _get_padded_batch_size()
        while True:
            self._dummy_run(batch_size,
                            seq_len,
                            kv_caches,
                            exec_mode=ExecutionMode.DECODE)
            xm.wait_device_ops()
            logger.info("batch_size: %d, seq_len: %d", batch_size, seq_len)

            if batch_size >= self.scheduler_config.max_num_seqs:
                break
            batch_size = batch_size + 16 if batch_size >= 16 else batch_size * 2

        end = time.time()
        logger.info("Compilation for decode done in %.2f s.", end - start)

    def _prepare_prompt(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor]:
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[int] = []
        input_positions: List[int] = []
        prompt_lens: List[int] = []
        context_lens: List[int] = []
        slot_mapping: List[int] = []

        for batch_idx, seq_group_metadata in enumerate(
                seq_group_metadata_list):
            assert seq_group_metadata.is_prompt
            seq_ids = list(seq_group_metadata.seq_data.keys())
            assert len(seq_ids) == 1
            seq_id = seq_ids[0]

            seq_data = seq_group_metadata.seq_data[seq_id]
            # Could include output tokens when a request is preempted.
            prompt_tokens = seq_data.get_token_ids()
            seq_len = len(prompt_tokens)

            num_computed_blocks = len(seq_group_metadata.computed_block_nums)
            num_computed_tokens = num_computed_blocks * self.block_size
            if num_computed_tokens > 0:
                prompt_tokens = prompt_tokens[num_computed_tokens:]
                context_lens.append(seq_len)
            else:
                context_lens.append(0)

            prompt_len = len(prompt_tokens)
            prompt_lens.append(prompt_len)

            input_tokens.extend(prompt_tokens)
            input_positions.extend(range(num_computed_tokens, seq_len))

            assert seq_group_metadata.block_tables is not None
            block_table = seq_group_metadata.block_tables[seq_id]
            for i in range(num_computed_tokens, seq_len):
                block_number = block_table[i // self.block_size]
                block_offset = i % self.block_size
                slot = block_number * self.block_size + block_offset
                slot_mapping.append(slot)
            if num_computed_tokens > 0:
                self.block_tables[batch_idx, :len(block_table)] = block_table

            # Add paddings to EACH prompt to the smallest power of 2 that is
            # greater than or equal to the prompt length.
            # We pad the seq_len to reduce the compilation overhead.
            # We execute each prompt individually (i.e., with batch_size 1)
            # because the FlashAttention kernel does not support ragged inputs.
            # TODO(woosuk): Use SplashAttention to support ragged inputs.
            padded_prompt_len = _get_padded_prefill_len(prompt_len)
            num_paddings = padded_prompt_len - prompt_len
            input_tokens += [0] * num_paddings
            input_positions += [0] * num_paddings
            slot_mapping += [_PAD_SLOT_ID] * num_paddings

        assert len(prompt_lens) > 0
        num_prefills = len(prompt_lens)
        input_tokens = torch.tensor(input_tokens,
                                    dtype=torch.int32,
                                    device="cpu")
        input_positions = torch.tensor(input_positions,
                                       dtype=torch.int32,
                                       device="cpu")
        slot_mapping = torch.tensor(slot_mapping,
                                    dtype=torch.int64,
                                    device="cpu")
        prompt_lens = torch.tensor(prompt_lens,
                                   dtype=torch.int32,
                                   device="cpu")
        context_lens = torch.tensor(context_lens,
                                    dtype=torch.int32,
                                    device="cpu")
        block_tables = torch.tensor(self.block_tables[:num_prefills],
                                    dtype=torch.int32,
                                    device="cpu")
        attn_metadata = self.attn_backend.make_metadata(
            num_prefills=num_prefills,
            num_prefill_tokens=0,  # NOTE: This is not used.
            num_decode_tokens=0,
            slot_mapping=slot_mapping,
            multi_modal_placeholder_index_maps=None,
            enable_kv_scales_calculation=False,
            block_tables=block_tables,
            context_lens=context_lens,
            effective_query_lens=prompt_lens,
        )
        return input_tokens, input_positions, attn_metadata, prompt_lens

    def _prepare_decode(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, torch.Tensor]:
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[List[int]] = []
        input_positions: List[List[int]] = []
        slot_mapping: List[List[int]] = []
        context_lens: List[int] = []

        batch_idx = 0
        for seq_group_metadata in seq_group_metadata_list:
            assert not seq_group_metadata.is_prompt
            seq_ids = list(seq_group_metadata.seq_data.keys())
            for seq_id in seq_ids:
                seq_data = seq_group_metadata.seq_data[seq_id]
                generation_token = seq_data.get_last_token_id()
                input_tokens.append([generation_token])

                seq_len = seq_data.get_len()
                position = seq_len - 1
                input_positions.append([position])
                context_lens.append(seq_len)

                assert seq_group_metadata.block_tables is not None
                block_table = seq_group_metadata.block_tables[seq_id]
                self.block_tables[batch_idx, :len(block_table)] = block_table
                batch_idx += 1

                block_number = block_table[position // self.block_size]
                block_offset = position % self.block_size
                slot = block_number * self.block_size + block_offset
                slot_mapping.append([slot])

        batch_size = _get_padded_batch_size(batch_idx)
        num_paddings = batch_size - batch_idx
        input_tokens = input_tokens + [[0]] * num_paddings
        input_positions = input_positions + [[0]] * num_paddings
        slot_mapping = slot_mapping + [[_PAD_SLOT_ID]] * num_paddings
        context_lens = context_lens + [0] * num_paddings

        input_tokens = torch.tensor(input_tokens,
                                    dtype=torch.int32,
                                    device="cpu")
        input_positions = torch.tensor(input_positions,
                                       dtype=torch.int32,
                                       device="cpu")
        slot_mapping = torch.tensor(slot_mapping,
                                    dtype=torch.int64,
                                    device="cpu")
        context_lens = torch.tensor(context_lens,
                                    dtype=torch.int32,
                                    device="cpu")
        block_tables = torch.tensor(self.block_tables[:batch_size],
                                    dtype=torch.int32,
                                    device="cpu")
        input_lens = torch.tensor([1] * batch_size,
                                  dtype=torch.int32,
                                  device="cpu")
        attn_metadata = self.attn_backend.make_metadata(
            num_prefills=0,
            num_prefill_tokens=0,
            num_decode_tokens=batch_size,
            slot_mapping=slot_mapping,
            multi_modal_placeholder_index_maps=None,
            enable_kv_scales_calculation=False,
            block_tables=block_tables,
            context_lens=context_lens,
        )
        return input_tokens, input_positions, attn_metadata, input_lens

    def _prepare_sample(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        padded_batch_size: int,
    ) -> Tuple[torch.Tensor, torch.Tensor, List[int]]:
        assert len(seq_group_metadata_list) > 0
        t = []
        p = []
        n = []
        for seq_group_metadata in seq_group_metadata_list:
            sampling_params = seq_group_metadata.sampling_params
            t.append(sampling_params.temperature)
            if sampling_params.top_p != 1 and not _ENABLE_TOP_P:
                raise NotImplementedError(
                    "Top-p sampling is currently disabled for the TPU backend "
                    "due to performance issues.")
            p.append(sampling_params.top_p)
            if sampling_params.top_k > 0:
                raise NotImplementedError(
                    "Top-k sampling is currently disabled for the TPU backend "
                    "due to performance issues.")
            if sampling_params.n > _MAX_NUM_SAMPLES:
                raise NotImplementedError(
                    f"Best of > {_MAX_NUM_SAMPLES} is not supported by the TPU "
                    "backend.")
            n.append(sampling_params.n)
            if sampling_params.logprobs is not None:
                raise NotImplementedError(
                    "logprobs is not currently supported by the TPU backend.")
            if sampling_params.prompt_logprobs is not None:
                raise NotImplementedError(
                    "prompt_logprobs is not currently supported by the TPU "
                    "backend.")

            # Repeat the sampling params if the seq group has multiple seqs.
            num_seqs = len(seq_group_metadata.seq_data)
            t += [t[-1]] * (num_seqs - 1)
            p += [p[-1]] * (num_seqs - 1)
            n += [n[-1]] * (num_seqs - 1)

        num_paddings = padded_batch_size - len(t)
        t += [1.0] * num_paddings
        p += [1.0] * num_paddings

        t = torch.tensor(t, dtype=torch.float32, device="cpu")
        p = torch.tensor(p, dtype=torch.float32, device="cpu")
        return t, p, n

    def prepare_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        virtual_engine: int = 0,
        finished_requests_ids: Optional[List[str]] = None,
    ) -> ModelInputForTPU:
        del finished_requests_ids  # Unused.
        assert virtual_engine == 0
        assert len(seq_group_metadata_list) > 0
        # NOTE: We assume that all sequences in the group are all prompts or
        # all decodes.
        is_prompt = seq_group_metadata_list[0].is_prompt
        if is_prompt:
            inputs = self._prepare_prompt(seq_group_metadata_list)
        else:
            inputs = self._prepare_decode(seq_group_metadata_list)
        input_tokens, input_positions, attn_metadata, input_lens = inputs
        padded_batch_size = input_tokens.shape[0]
        t, p, n = self._prepare_sample(seq_group_metadata_list,
                                       padded_batch_size)
        num_samples = _MAX_NUM_SAMPLES if is_prompt else 1

        seq_groups = [
            list(metadata.seq_data.keys())
            for metadata in seq_group_metadata_list
        ]
        return ModelInputForTPU(input_tokens, input_positions, attn_metadata,
                                input_lens, t, p, num_samples, n, seq_groups)

    def make_model_input_from_broadcasted_tensor_dict(
            self, tensor_dict: Dict[str, Any]) -> ModelInputForTPU:
        model_input = ModelInputForTPU.from_broadcasted_tensor_dict(
            tensor_dict, attn_backend=self.attn_backend)
        return model_input

    @torch.no_grad()
    def execute_model(
        self,
        model_input: ModelInputForTPU,
        kv_caches: Optional[List[Any]],
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
    ) -> List[SamplerOutput]:
        assert intermediate_tensors is None
        if not model_input.is_first_multi_step:
            if not model_input.is_last_step:
                return []

            use_async_out_proc = model_input.async_callback is not None
            sampler_outputs = []
            num_outputs = len(self.cached_step_outputs)
            for i in range(num_outputs):
                next_token_ids = self.cached_step_outputs.pop(0)
                next_token_ids = next_token_ids.cpu().tolist()
                sampler_output = _make_decode_output(next_token_ids,
                                                     model_input.seq_groups)
                sampler_outputs.append(sampler_output)

                if i < num_outputs - 1 and use_async_out_proc:
                    assert model_input.async_callback is not None
                    ctx = model_input.async_callback.keywords[  # type: ignore
                        "ctx"]
                    ctx.append_output(
                        outputs=[sampler_output],
                        seq_group_metadata_list=ctx.seq_group_metadata_list,
                        scheduler_outputs=ctx.scheduler_outputs,
                        is_async=False,
                        is_last_step=False,
                        is_first_step_output=i == 0)
                    model_input.async_callback()
            if use_async_out_proc:
                return [sampler_outputs[-1]]
            else:
                return sampler_outputs

        is_prompt = model_input.attn_metadata.num_prefills > 0
        if is_prompt:
            assert num_steps == 1
            # NOTE(woosuk): Since the FlashAttention kernel does not support
            # ragged inputs, we split the prompts into different batches and
            # process them separately. This is a temporary hack that should be
            # optimized by using SplashAttention.
            orig_slot_mapping = model_input.attn_metadata.slot_mapping
            orig_block_tables = model_input.attn_metadata.block_tables
            orig_context_lens = model_input.attn_metadata.context_lens
            orig_effective_query_lens = \
                model_input.attn_metadata.effective_query_lens
            batch_size = model_input.input_lens.shape[0]
            start_idx = 0
            next_token_ids = []
            for i in range(batch_size):
                # Get the actual prefill_len.
                prefill_len = model_input.input_lens[i:i + 1].item()
                prefill_len = _get_padded_prefill_len(prefill_len)
                end_idx = start_idx + prefill_len

                token_ids = model_input.token_ids[None, start_idx:end_idx].to(
                    self.device)
                position_ids = model_input.position_ids[None,
                                                        start_idx:end_idx].to(
                                                            self.device)
                attn_metadata = model_input.attn_metadata
                attn_metadata.num_prefills = 1
                attn_metadata.slot_mapping = orig_slot_mapping[
                    None, start_idx:end_idx].to(self.device)
                if orig_context_lens[i].item() > 0:
                    attn_metadata.context_lens = orig_context_lens[i:i + 1].to(
                        self.device)
                    attn_metadata.block_tables = orig_block_tables[
                        i].unsqueeze(0).to(self.device)
                    attn_metadata.effective_query_lens = \
                        orig_effective_query_lens[i:i + 1].to(self.device)
                else:
                    attn_metadata.context_lens = None
                    attn_metadata.block_tables = None
                    attn_metadata.effective_query_lens = None
                input_lens = model_input.input_lens[i:i + 1].to(self.device)
                t = model_input.t[i:i + 1].to(self.device)
                p = model_input.p[i:i + 1].to(self.device)
                with set_forward_context(model_input.attn_metadata,
                                         self.vllm_config,
                                         model_input.virtual_engine):
                    output_token_ids = self.model(token_ids, position_ids,
                                                  input_lens, t, p,
                                                  model_input.num_samples,
                                                  kv_caches)
                next_token_ids.append(output_token_ids[0])
                start_idx = end_idx

            if model_input.async_callback is not None:
                model_input.async_callback()
            # Retrieve the outputs to CPU.
            next_token_ids = [
                output_token_ids.cpu().tolist()
                for output_token_ids in next_token_ids
            ]

            # NOTE(woosuk): Minimal code to construct the sampler outputs.
            # The TPU backend does not reuse the sampler, since the TPU backend
            # does not support advanced sampling parameters such as logprobs.
            zero_logprob = Logprob(0.0)
            sampler_outputs = []
            for i, seq_group in enumerate(model_input.seq_groups):
                seq_ids = seq_group
                assert len(seq_ids) == 1
                seq_id = seq_ids[0]
                seq_outputs = []
                for j in range(model_input.n[i]):
                    next_token_id = next_token_ids[i][j]
                    seq_outputs.append(
                        SequenceOutput(seq_id, next_token_id,
                                       {next_token_id: zero_logprob}))
                sampler_outputs.append(
                    CompletionSequenceGroupOutput(seq_outputs, None))
            return [SamplerOutput(sampler_outputs)]
        else:
            token_ids = model_input.token_ids.to(self.device)
            position_ids = model_input.position_ids.to(self.device)
            attn_metadata = model_input.attn_metadata
            attn_metadata.slot_mapping = attn_metadata.slot_mapping.to(
                self.device)
            attn_metadata.block_tables = attn_metadata.block_tables.to(
                self.device)
            attn_metadata.context_lens = attn_metadata.context_lens.to(
                self.device)
            t = model_input.t.to(self.device)
            p = model_input.p.to(self.device)
            input_lens = model_input.input_lens.to(self.device)
            for i in range(num_steps):
                slot_mapping = attn_metadata.slot_mapping
                with set_forward_context(model_input.attn_metadata,
                                         self.vllm_config,
                                         model_input.virtual_engine):
                    output_token_ids = self.model(token_ids, position_ids,
                                                  input_lens, t, p,
                                                  model_input.num_samples,
                                                  kv_caches)
                self.cached_step_outputs.append(output_token_ids)

                if i < num_steps - 1:
                    # Prepare the inputs for the next step.
                    token_ids = output_token_ids.unsqueeze(dim=1).int()
                    position_ids = position_ids + 1
                    attn_metadata.context_lens = attn_metadata.context_lens + 1

                    block_tables = attn_metadata.block_tables
                    block_number = block_tables.gather(
                        1,
                        position_ids.long() // self.block_size)
                    block_offset = position_ids % self.block_size

                    is_padding = slot_mapping == _PAD_SLOT_ID
                    slot_mapping = block_number * self.block_size + block_offset
                    slot_mapping = slot_mapping.long()
                    slot_mapping = torch.where(is_padding, _PAD_SLOT_ID,
                                               slot_mapping)
                    attn_metadata.slot_mapping = slot_mapping

            if model_input.async_callback is not None:
                model_input.async_callback()

            if num_steps > 1:
                return []
            # Retrieve the outputs to CPU.
            next_token_ids = self.cached_step_outputs.pop(0)
            next_token_ids = next_token_ids.cpu().tolist()
            sampler_output = _make_decode_output(next_token_ids,
                                                 model_input.seq_groups)
            return [sampler_output]

class ModelWrapper(nn.Module):

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model

    def forward(
        self,
        token_ids: torch.Tensor,
        position_ids: torch.Tensor,
        input_lens: torch.Tensor,
        t: torch.Tensor,
        p: torch.Tensor,
        num_samples: int,
        kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
    ) -> torch.Tensor:
        """Executes the forward pass of the model and samples the next token.

        Args:
            token_ids: The input token IDs of shape [batch_size, seq_len].
            position_ids: The input position IDs of shape [batch_size, seq_len].
            input_lens: The actual input lengths of shape [batch_size].
            t: The sampling temperature of shape [batch_size].
            p: The top-p probability of shape [batch_size].
            num_samples: Number of samples to draw from each logits vector.
            kv_caches: The key and value caches. They can be None during the
                memory profiling at initialization.
        """
        batch_size, seq_len = token_ids.shape
        # Calculate the positions to sample from.
        start_indices = torch.arange(
            batch_size, dtype=torch.int32, device=input_lens.device) * seq_len
        logits_indices = start_indices + input_lens - 1
        attn_metadata = get_forward_context().attn_metadata

        # FIXME(woosuk): This is a temporary hack to avoid using the existing
        # sampler and sampling metadata.
        sampling_metadata = SamplingMetadata(
            seq_groups=[],
            selected_token_indices=logits_indices,
            categorized_sample_indices={},
            num_prompts=attn_metadata.num_prefills,
        )

        # Skip this in memory profiling at initialization.
        if kv_caches[0][0].numel() > 0:
            # index_copy_(slot_mapping) only works when the inserted dimension
            # is 0. However, the KV cache in the Pallas backend has the shape
            # [num_kv_heads, num_blocks, block_size, head_size]. To make it
            # work, we need to flatten the first three dimensions and modify
            # the slot_mapping accordingly.
            num_kv_heads, num_blocks, block_size, _ = kv_caches[0][0].shape
            slot_mapping = attn_metadata.slot_mapping
            slot_mapping = slot_mapping.flatten()
            head_indices = torch.arange(0,
                                        num_kv_heads,
                                        device=slot_mapping.device,
                                        dtype=slot_mapping.dtype)
            head_indices *= block_size * num_blocks
            slot_mapping = slot_mapping.repeat_interleave(num_kv_heads).view(
                -1, num_kv_heads)
            slot_mapping = slot_mapping + head_indices.view(1, -1)
            slot_mapping = slot_mapping.flatten()
            attn_metadata.slot_mapping = slot_mapping

        hidden_states = self.model(token_ids, position_ids)
        hidden_states = hidden_states.flatten(0, 1)
        logits = self.model.compute_logits(hidden_states, sampling_metadata)

        # Argmax sampling.
        argmax_token_ids = torch.argmax(logits, dim=-1, keepdim=True)
        argmax_token_ids = argmax_token_ids.repeat(1, num_samples)

        # Zero temperature means greedy decoding. Avoid division by zero.
        nonzero_t = torch.where(t != 0, t, 1.0)
        logits = logits / nonzero_t.unsqueeze(dim=1)
        if _ENABLE_TOP_P:
            logits = _apply_top_p(logits, p.unsqueeze(dim=1))

        # Random sampling.
        probs = torch.softmax(logits, dim=-1, dtype=torch.float32)
        sampled_token_ids = torch.multinomial(probs,
                                              num_samples,
                                              replacement=True)
        if num_samples == 1:
            argmax_token_ids = argmax_token_ids.squeeze(dim=-1)
            sampled_token_ids = sampled_token_ids.squeeze(dim=-1)
        next_token_ids = torch.where(t != 0, sampled_token_ids,
                                     argmax_token_ids)
        return next_token_ids


def _get_padded_prefill_len(x: int) -> int:
    # NOTE(woosuk): The pallas FlashAttention kernel requires the sequence
    # length to be a multiple of 16. We pad the prompt length to the nearest
|
||||||
|
# multiple of 16. This is also good for performance.
|
||||||
|
if x <= 16:
|
||||||
|
return 16
|
||||||
|
return 1 << (x - 1).bit_length()
|
||||||
|
|
||||||
|
|
||||||
|
def _get_padded_batch_size(batch_size: int) -> int:
|
||||||
|
# The GMM Pallas kernel requires num_tokens * topk to be a multiple of 16.
|
||||||
|
# To meet this requirement in the simplest way, we set the minimal batch
|
||||||
|
# size to 8.
|
||||||
|
if batch_size <= 8:
|
||||||
|
return 8
|
||||||
|
else:
|
||||||
|
return ((batch_size + 15) // 16) * 16
|
||||||
|
|
||||||
|
|
||||||
|
def _apply_top_p(logits: torch.Tensor, p: torch.Tensor) -> torch.Tensor:
|
||||||
|
logits_sorted = torch.sort(logits, dim=-1, descending=True).values
|
||||||
|
sorted_cum_probs = torch.cumsum(logits_sorted.softmax(dim=-1), dim=-1)
|
||||||
|
cutoff_index = torch.sum(sorted_cum_probs < p, dim=-1, keepdim=True)
|
||||||
|
cutoff_logit = torch.gather(logits_sorted, -1, cutoff_index)
|
||||||
|
logits = logits.masked_fill_(logits < cutoff_logit, -float("inf"))
|
||||||
|
return logits
|
||||||
|
|
||||||
|
|
||||||
|
def _make_decode_output(
|
||||||
|
next_token_ids: List[int],
|
||||||
|
seq_groups: List[List[int]],
|
||||||
|
) -> SamplerOutput:
|
||||||
|
zero_logprob = Logprob(0.0)
|
||||||
|
sampler_outputs = []
|
||||||
|
batch_idx = 0
|
||||||
|
for seq_group in seq_groups:
|
||||||
|
seq_ids = seq_group
|
||||||
|
seq_outputs = []
|
||||||
|
for seq_id in seq_ids:
|
||||||
|
next_token_id = next_token_ids[batch_idx]
|
||||||
|
seq_outputs.append(
|
||||||
|
SequenceOutput(seq_id, next_token_id,
|
||||||
|
{next_token_id: zero_logprob}))
|
||||||
|
batch_idx += 1
|
||||||
|
sampler_outputs.append(CompletionSequenceGroupOutput(
|
||||||
|
seq_outputs, None))
|
||||||
|
return SamplerOutput(sampler_outputs)
|
|
@ -0,0 +1,337 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import os
from typing import List, Optional, Tuple, Union

import torch
import torch_xla.core.xla_model as xm
import torch_xla.debug.profiler as xp
import torch_xla.runtime as xr

import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.distributed import (ensure_model_parallel_initialized,
                              init_distributed_environment)
from vllm.logger import init_logger
from vllm.model_executor import set_random_seed
from vllm.sequence import ExecuteModelRequest
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, bind_kv_cache, get_dtype_size
from vllm.worker.tpu_model_runner import ExecutionMode, TPUModelRunner
from vllm.worker.worker_base import (LocalOrDistributedWorkerBase,
                                     LoRANotSupportedWorkerBase, WorkerBase,
                                     WorkerInput)

logger = init_logger(__name__)


class TPUWorker(LoRANotSupportedWorkerBase, LocalOrDistributedWorkerBase):

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        is_driver_worker: bool,
    ) -> None:
        WorkerBase.__init__(self, vllm_config=vllm_config)
        self.parallel_config.rank = rank
        self.local_rank = local_rank
        self.rank = rank
        self.distributed_init_method = distributed_init_method
        self.is_driver_worker = is_driver_worker

        assert self.device_config.device_type == "tpu"
        if self.cache_config.cache_dtype == "auto":
            self.cache_dtype = self.model_config.dtype
        else:
            self.cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[
                self.cache_config.cache_dtype]

        self.model_runner: TPUModelRunner = TPUModelRunner(
            vllm_config=vllm_config, is_driver_worker=is_driver_worker)

        if self.model_config.seed is None:
            self.model_config.seed = 0

        if vllm_config.lora_config is not None:
            raise NotImplementedError(
                "The V0 TPU backend doesn't support LoRA serving")

    def init_device(self) -> None:
        os.environ["PJRT_DEVICE"] = "TPU"
        torch.set_grad_enabled(False)
        torch.set_default_dtype(self.model_config.dtype)

        # NOTE(woosuk): This is just to initialize the TP group and broadcast
        # the input objects on CPU. The all-reduce and all-gather ops on TPU
        # are invoked by `xm.all_reduce` and `xm.all_gather` which use their
        # own context.
        init_distributed_environment(
            world_size=self.parallel_config.world_size,
            rank=self.rank,
            local_rank=self.local_rank,
            distributed_init_method=self.distributed_init_method,
            backend="gloo",
        )
        ensure_model_parallel_initialized(
            self.parallel_config.tensor_parallel_size,
            self.parallel_config.pipeline_parallel_size)

        # Device initialization should happen after initializing the
        # distributed runtime.
        self.device = xm.xla_device()
        self.device_config.device = self.device

        # Set random seed.
        set_random_seed(self.model_config.seed)
        xm.set_rng_state(self.model_config.seed, self.device)

        # Increase the cache size limit, which is the maximum number of
        # dynamo graphs that can be compiled.
        # NOTE(woosuk): Usually, we compile 10-15 graphs for prefill and
        # 30-40 graphs for decode. 128 is an arbitrary safe number.
        torch._dynamo.config.cache_size_limit = 128
        # Use persistent cache to avoid XLA recompilation.
        # NOTE(woosuk): Set per-rank cache path since different ranks
        # can have slightly different XLA graphs.
        world_size = self.parallel_config.world_size
        rank = xr.global_ordinal()
        # The PyTorch/XLA compilation cache uses the Torch IR to generate keys.
        # Consequently, changes in optimization flags, which affect compilation
        # results, don't change the cache key. This can result in the wrong
        # compilation being used. To prevent this, disabling the XLA
        # compilation cache during development is recommended. We can disable
        # it by `export VLLM_XLA_CACHE_PATH=`
        if envs.VLLM_XLA_CACHE_PATH:
            per_rank_path = os.path.join(envs.VLLM_XLA_CACHE_PATH,
                                         f"tp{world_size}_rank{rank}")
            xr.initialize_cache(per_rank_path, readonly=False)

        self.profiler = None
        if envs.VLLM_TORCH_PROFILER_DIR and self.rank < 1:
            # For TPU, we can only have 1 active profiler session for 1
            # profiler server. So we only profile on rank 0.
            self.profile_dir = envs.VLLM_TORCH_PROFILER_DIR
            logger.info("Profiling enabled. Traces will be saved to: %s",
                        self.profile_dir)
            self.profiler = xp.start_server(9012)

    def start_profile(self):
        if self.rank < 1:
            if self.profiler is None:
                raise RuntimeError("Profiler is not enabled.")
            xp.start_trace(self.profile_dir)

    def stop_profile(self):
        if self.rank < 1:
            if self.profiler is None:
                raise RuntimeError("Profiler is not enabled.")
            xp.stop_trace()

    def load_model(self):
        self.model_runner.load_model()

    def determine_num_available_blocks(self) -> Tuple[int, int]:
        num_layers = self.model_config.get_num_layers(self.parallel_config)
        head_size = self.model_config.get_head_size()
        num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)

        # Use an empty tensor instead of `None` to force Dynamo to pass
        # it by reference, rather than specializing on the value `None`.
        # The `dtype` argument does not matter, and we use `float32` as
        # a placeholder (it has wide hardware support).
        kv_caches = [(torch.tensor([], dtype=torch.float32,
                                   device=self.device),
                      torch.tensor([], dtype=torch.float32,
                                   device=self.device))
                     for _ in range(num_layers)]
        bind_kv_cache(self.compilation_config.static_forward_context,
                      [kv_caches])
        self.model_runner._dummy_run(
            batch_size=1,
            seq_len=self.scheduler_config.max_num_batched_tokens,
            kv_caches=kv_caches,
            exec_mode=ExecutionMode.PREFILL,
        )
        # Synchronize before measuring the memory usage.
        xm.wait_device_ops()

        # Get the maximum amount of memory used by the model weights and
        # intermediate activations.
        m = xm.get_memory_info(self.device)
        total_memory_size = m["bytes_limit"]
        profiled = m["peak_bytes_used"]  # Weights + intermediate activations.

        # Calculate the TPU KV cache size based on profiling.
        usable_memory_size = int(total_memory_size *
                                 self.cache_config.gpu_memory_utilization)
        tpu_kv_cache_bytes = max(usable_memory_size - profiled, 0)
        dtype_bytes = get_dtype_size(self.cache_dtype)
        block_size_bytes = (dtype_bytes * self.cache_config.block_size *
                            num_layers * 2 * head_size * num_kv_heads)
        num_tpu_blocks = tpu_kv_cache_bytes // block_size_bytes
        num_tpu_blocks = (num_tpu_blocks // 8) * 8  # Round down to 8.

        # Calculate the CPU KV cache size based on the config.
        num_cpu_blocks = int(self.cache_config.swap_space_bytes //
                             block_size_bytes)
        num_cpu_blocks = (num_cpu_blocks // 8) * 8  # Round down to 8.
        return num_tpu_blocks, num_cpu_blocks

    def initialize_cache(
        self,
        num_gpu_blocks: int,
        num_cpu_blocks: int,
    ) -> None:
        self.cache_config.num_gpu_blocks = num_gpu_blocks
        self.cache_config.num_cpu_blocks = num_cpu_blocks
        self.block_size = self.cache_config.block_size

        dtype = self.cache_dtype
        num_layers = self.model_config.get_num_layers(self.parallel_config)
        num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)
        head_size = self.model_config.get_head_size()

        self.cpu_cache: List[Tuple[torch.Tensor, torch.Tensor]] = []
        self.tpu_cache: List[Tuple[torch.Tensor, torch.Tensor]] = []
        tpu_cache_shape = self.model_runner.attn_backend.get_kv_cache_shape(
            num_gpu_blocks, self.block_size, num_kv_heads, head_size)
        cpu_cache_shape = self.model_runner.attn_backend.get_kv_cache_shape(
            num_cpu_blocks, self.block_size, num_kv_heads, head_size)
        for _ in range(num_layers):
            tpu_k_cache = torch.zeros(tpu_cache_shape,
                                      dtype=dtype,
                                      device=self.device)
            tpu_v_cache = torch.zeros_like(tpu_k_cache)
            self.tpu_cache.append((tpu_k_cache, tpu_v_cache))
            cpu_k_cache = torch.zeros(cpu_cache_shape,
                                      dtype=dtype,
                                      device="cpu")
            cpu_v_cache = torch.zeros_like(cpu_k_cache)
            self.cpu_cache.append((cpu_k_cache, cpu_v_cache))
        bind_kv_cache(self.compilation_config.static_forward_context,
                      [self.tpu_cache])
        self._warmup_model()

    def _warmup_model(self) -> None:
        # FIXME(woosuk): Here we are abusing `enforce_eager` which is defined
        # for CUDA graphs. We should refactor this part.
        if not self.model_config.enforce_eager:
            # Warm up the model with all possible input shapes so that
            # compilation never happens during the actual execution.
            # This may take ~30 mins for the first run and ~20 mins for the
            # subsequent runs.
            # If `enforce_eager` is True, the ahead-of-time compilation is
            # skipped and the compilation happens during the actual execution,
            # which is bad for performance but useful for development.
            self.model_runner.warmup_model(self.tpu_cache)

    def get_cache_block_size_bytes(self) -> int:
        head_size = self.model_config.get_head_size()
        num_heads = self.model_config.get_num_kv_heads(self.parallel_config)
        num_layers = self.model_config.get_num_layers(self.parallel_config)

        key_cache_block = self.cache_config.block_size * num_heads * head_size
        value_cache_block = key_cache_block
        total = num_layers * (key_cache_block + value_cache_block)
        dtype_size = get_dtype_size(self.cache_dtype)
        return dtype_size * total

    @property
    def do_metadata_broadcast(self) -> bool:
        return self.parallel_config.tensor_parallel_size > 1

    @property
    def kv_cache(self) -> Optional[List[List[torch.Tensor]]]:
        # NOTE(woosuk): This assumes virtual_engine == 0, i.e., no pipeline
        # parallelism.
        return [self.tpu_cache]

    def prepare_worker_input(
        self,
        execute_model_req: ExecuteModelRequest,
    ) -> WorkerInput:
        virtual_engine = execute_model_req.virtual_engine
        num_seq_groups = len(execute_model_req.seq_group_metadata_list)
        blocks_to_swap_in = _make_src_to_dst(
            execute_model_req.blocks_to_swap_in, "cpu", self.device)
        blocks_to_swap_out = _make_src_to_dst(
            execute_model_req.blocks_to_swap_out, self.device, "cpu")
        blocks_to_copy = _make_src_to_dst(execute_model_req.blocks_to_copy,
                                          self.device, self.device)
        return WorkerInput(
            num_seq_groups=num_seq_groups,
            blocks_to_swap_in=blocks_to_swap_in,
            blocks_to_swap_out=blocks_to_swap_out,
            blocks_to_copy=blocks_to_copy,
            virtual_engine=virtual_engine,
        )

    def execute_worker(self, worker_input: WorkerInput) -> None:
        virtual_engine = worker_input.virtual_engine
        assert virtual_engine == 0
        attn_backend = self.model_runner.attn_backend
        num_layers = self.model_config.get_num_layers(self.parallel_config)

        # Issue cache operations.
        if worker_input.blocks_to_swap_in is not None:
            src_indices, dst_indices = worker_input.blocks_to_swap_in
            if src_indices.numel() > 0:
                # Swap from CPU to TPU.
                for i in range(num_layers):
                    tpu_k_cache, tpu_v_cache = self.tpu_cache[i]
                    cpu_k_cache, cpu_v_cache = self.cpu_cache[i]
                    k = cpu_k_cache[:, src_indices].to(self.device)
                    v = cpu_v_cache[:, src_indices].to(self.device)
                    _insert_kv(k, v, dst_indices, tpu_k_cache, tpu_v_cache)

        if worker_input.blocks_to_swap_out is not None:
            src_indices, dst_indices = worker_input.blocks_to_swap_out
            if src_indices.numel() > 0:
                # Swap from TPU to CPU.
                for i in range(num_layers):
                    tpu_k_cache, tpu_v_cache = self.tpu_cache[i]
                    cpu_k_cache, cpu_v_cache = self.cpu_cache[i]
                    cpu_k_cache[:, dst_indices] = tpu_k_cache[:, src_indices]
                    cpu_v_cache[:, dst_indices] = tpu_v_cache[:, src_indices]

        if worker_input.blocks_to_copy is not None:
            src_indices, dst_indices = worker_input.blocks_to_copy
            if src_indices.numel() > 0:
                attn_backend.copy_blocks(self.tpu_cache,
                                         (src_indices, dst_indices))


def _make_src_to_dst(
    mapping: List[Tuple[int, int]],
    src_device: Union[torch.device, str],
    dst_device: Union[torch.device, str],
) -> Optional[Tuple[torch.Tensor, torch.Tensor]]:
    if not mapping:
        return None

    src_indices = [i for i, _ in mapping]
    dst_indices = [i for _, i in mapping]
    src_indices = torch.tensor(src_indices,
                               device=src_device,
                               dtype=torch.int64)
    dst_indices = torch.tensor(dst_indices,
                               device=dst_device,
                               dtype=torch.int64)
    return src_indices, dst_indices


@torch.compile(backend="openxla")
def _insert_kv(
    k: torch.Tensor,
    v: torch.Tensor,
    indices: torch.Tensor,
    tpu_k_cache: torch.Tensor,
    tpu_v_cache: torch.Tensor,
) -> None:
    torch.ops.xla.dynamo_set_buffer_donor_(tpu_k_cache, True)
    torch.ops.xla.dynamo_set_buffer_donor_(tpu_v_cache, True)
    tpu_k_cache[:, indices] = k
    tpu_v_cache[:, indices] = v
@ -0,0 +1,606 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import dataclasses
import time
import weakref
from collections import defaultdict
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple,
                    Type, TypeVar)

import torch
import torch.nn as nn

from vllm.attention import get_attn_backend
from vllm.config import VllmConfig
from vllm.distributed import get_pp_group
from vllm.forward_context import set_forward_context
from vllm.inputs import INPUT_REGISTRY, InputRegistry
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadataCache
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
from vllm.model_executor.model_loader import get_model
from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs,
                             MultiModalKwargs, MultiModalPlaceholderMap,
                             MultiModalRegistry)
from vllm.sampling_params import SamplingParams
from vllm.sequence import IntermediateTensors, SequenceGroupMetadata
from vllm.utils import DeviceMemoryProfiler, GiB_bytes, make_tensor_with_pad
from vllm.worker.model_runner import AttentionMetadata, SamplingMetadata
from vllm.worker.model_runner_base import (
    ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase,
    _add_attn_metadata_broadcastable_dict,
    _add_sampling_metadata_broadcastable_dict,
    _init_attn_metadata_from_tensor_dict,
    _init_sampling_metadata_from_tensor_dict)

if TYPE_CHECKING:
    from vllm.attention.backends.abstract import AttentionBackend

logger = init_logger(__name__)

_PAD_SLOT_ID = -1

TModelInputForXPU = TypeVar('TModelInputForXPU', bound="ModelInputForXPU")


@dataclass(frozen=True)
class ModelInputForXPU(ModelRunnerInputBase):
    """
    Used by the XPUModelRunner.
    """
    input_tokens: Optional[torch.Tensor] = None
    input_positions: Optional[torch.Tensor] = None
    attn_metadata: Optional["AttentionMetadata"] = None
    multi_modal_kwargs: Optional[BatchedTensorInputs] = None
    virtual_engine: Optional[int] = None
    seq_lens: Optional[List[int]] = None
    query_lens: Optional[List[int]] = None
    async_callback: Optional[Callable] = None

    def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
        tensor_dict = {
            "input_tokens": self.input_tokens,
            "input_positions": self.input_positions,
        }
        _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)

        return tensor_dict

    @classmethod
    def from_broadcasted_tensor_dict(
        cls: Type[TModelInputForXPU],
        tensor_dict: Dict[str, Any],
        attn_backend: Optional["AttentionBackend"] = None,
    ) -> TModelInputForXPU:
        if attn_backend is not None:
            tensor_dict = _init_attn_metadata_from_tensor_dict(
                attn_backend, tensor_dict)
        return cls(**tensor_dict)


@dataclass(frozen=True)
class ModelInputForXPUWithSamplingMetadata(ModelInputForXPU):
    """
    Used by the ModelRunner.
    """
    sampling_metadata: Optional["SamplingMetadata"] = None

    def as_broadcastable_tensor_dict(self) -> Dict[str, Any]:
        tensor_dict = {
            "input_tokens": self.input_tokens,
            "input_positions": self.input_positions,
        }
        _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata)
        _add_sampling_metadata_broadcastable_dict(tensor_dict,
                                                  self.sampling_metadata)
        return tensor_dict

    @classmethod
    def from_broadcasted_tensor_dict(
        cls,
        tensor_dict: Dict[str, Any],
        attn_backend: Optional["AttentionBackend"] = None,
    ) -> "ModelInputForXPUWithSamplingMetadata":
        tensor_dict = _init_sampling_metadata_from_tensor_dict(tensor_dict)
        if attn_backend is not None:
            tensor_dict = _init_attn_metadata_from_tensor_dict(
                attn_backend, tensor_dict)
        return cls(**tensor_dict)


class ModelInputForXPUBuilder(ModelRunnerInputBuilderBase[ModelInputForXPU]):

    def __init__(self,
                 runner: "XPUModelRunner",
                 finished_requests_ids: Optional[List[str]] = None) -> None:
        super().__init__()
        self.runner = runner
        self.model_input_cls = self.runner._model_input_cls
        self.attn_backend = self.runner.attn_backend
        self.sliding_window = self.runner.sliding_window
        self.block_size = self.runner.block_size
        self.device = self.runner.device

    def prepare(self,
                finished_requests_ids: Optional[List[str]] = None) -> None:
        self.seq_group_metadata_list: List[SequenceGroupMetadata] = []

    def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
        self.seq_group_metadata_list.append(seq_group_metadata)

    def build(self) -> ModelInputForXPU:
        is_prompt = self.seq_group_metadata_list[0].is_prompt
        # Prepare input tensors.
        if is_prompt:
            (input_tokens, input_positions, attn_metadata, seq_lens,
             multi_modal_kwargs) = self._prepare_prompt(
                 self.seq_group_metadata_list)
        else:
            (input_tokens, input_positions,
             attn_metadata) = self._prepare_decode(
                 self.seq_group_metadata_list)
            seq_lens = None
            multi_modal_kwargs = None

        return self.model_input_cls(
            input_tokens=input_tokens,
            input_positions=input_positions,
            attn_metadata=attn_metadata,
            multi_modal_kwargs=multi_modal_kwargs,
            seq_lens=seq_lens,
            query_lens=seq_lens,
        )

    def _prepare_prompt(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int],
               BatchedTensorInputs]:
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[int] = []
        input_positions: List[int] = []
        slot_mapping: List[int] = []
        seq_lens: List[int] = []
        multi_modal_kwargs_list: List[MultiModalKwargs] = []
        multi_modal_placeholder_maps: Dict[
            str,
            MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)

        for seq_group_metadata in seq_group_metadata_list:
            assert seq_group_metadata.is_prompt
            seq_ids = list(seq_group_metadata.seq_data.keys())
            assert len(seq_ids) == 1
            seq_id = seq_ids[0]

            seq_data = seq_group_metadata.seq_data[seq_id]
            prompt_tokens = seq_data.get_token_ids()
            computed_len = seq_data.get_num_computed_tokens()
            seq_len = len(prompt_tokens)

            seq_lens.append(seq_len)  # Prompt token num
            input_tokens.extend(prompt_tokens)  # Token ids

            # Token position ids
            # NOTE(woosuk): Here we assume that the first token in the prompt
            # is always the first token in the sequence.
            positions_range = range(computed_len, seq_len)
            input_positions.extend(list(positions_range))

            if seq_group_metadata.multi_modal_data:
                # NOTE: mm_kwargs only includes the subset of multi-modal items
                # that intersect with the current prefill positions.
                mm_kwargs, placeholder_maps = MultiModalPlaceholderMap \
                    .from_seq_group(seq_group_metadata, positions_range)

                multi_modal_kwargs_list.append(mm_kwargs)

                for modality, placeholder_map in placeholder_maps.items():
                    multi_modal_placeholder_maps[modality].extend(
                        placeholder_map)

            if seq_group_metadata.block_tables is None:
                # During memory profiling, the block tables are not initialized
                # yet. In this case, we just use a dummy slot mapping.
                slot_mapping.extend([_PAD_SLOT_ID] * seq_len)
                continue

            # Compute the slot mapping.
            block_table = seq_group_metadata.block_tables[seq_id]
            # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID,
            # where start_idx is max(0, seq_len - sliding_window).
            # For example, if the prompt len is 10, sliding window is 8, and
            # block size is 4, the first two tokens are masked and the slot
            # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1].
            start_idx = 0
            if self.sliding_window is not None:
                start_idx = max(0, seq_len - self.sliding_window)

            for i in range(computed_len, seq_len):
                if i < start_idx:
                    slot_mapping.append(_PAD_SLOT_ID)
                    continue

                block_number = block_table[i //
                                           self.block_size]  # type: ignore
                block_offset = i % self.block_size  # type: ignore
                slot = block_number * self.block_size + block_offset
                slot_mapping.append(slot)

        num_prompt_tokens = len(input_tokens)

        input_tokens = torch.tensor(input_tokens,
                                    dtype=torch.long,
                                    device=self.device)  # type: ignore
        input_positions = torch.tensor(input_positions,
                                       dtype=torch.long,
                                       device=self.device)  # type: ignore
        slot_mapping = torch.tensor(slot_mapping,
                                    dtype=torch.long,
                                    device=self.device)  # type: ignore
        placeholder_index_maps = {
            modality: placeholder_map.index_map()
            for modality, placeholder_map in
            multi_modal_placeholder_maps.items()
        }

        max_seqlen = max(seq_lens)
        tmp = [0]
        tmp.extend(seq_lens)
        seqlen = torch.tensor(tmp)
        seqlen_q = torch.cumsum(seqlen, dim=0).to(device=self.device)

        attn_metadata = self.attn_backend.make_metadata(
            is_prompt=True,
            slot_mapping=slot_mapping,
            multi_modal_placeholder_index_maps=placeholder_index_maps,
            enable_kv_scales_calculation=False,
            seq_lens=seq_lens,
            seqlen_q=seqlen_q,
            max_seqlen=max_seqlen,
            seq_lens_tensor=torch.tensor([]),
            max_decode_seq_len=0,
            num_prefills=len(seq_lens),
            num_prefill_tokens=num_prompt_tokens,
            num_decode_tokens=0,
            block_tables=torch.tensor([], device=self.device, dtype=torch.int),
        )

        multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list)

        return (input_tokens, input_positions, attn_metadata, seq_lens,
                multi_modal_kwargs)

    def _prepare_decode(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
    ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata]:
        assert len(seq_group_metadata_list) > 0
        input_tokens: List[int] = []
        input_positions: List[int] = []
        slot_mapping: List[int] = []
        seq_lens: List[int] = []
        block_tables: List[List[int]] = []

        for seq_group_metadata in seq_group_metadata_list:
            assert not seq_group_metadata.is_prompt
            assert seq_group_metadata.token_chunk_size == 1

            seq_ids = list(seq_group_metadata.seq_data.keys())

            for seq_id in seq_ids:
                seq_data = seq_group_metadata.seq_data[seq_id]
                generation_token = seq_data.get_last_token_id()
                input_tokens.append(generation_token)

                seq_len = seq_data.get_len()
                position = seq_len - 1
                input_positions.append(position)

                seq_len = seq_len if self.sliding_window is None else min(
                    seq_len, self.sliding_window)
                seq_lens.append(seq_len)

                block_table = seq_group_metadata.block_tables[seq_id]
                block_number = block_table[position // self.block_size]
                block_offset = position % self.block_size
                slot = block_number * self.block_size + block_offset
                slot_mapping.append(slot)

                if self.sliding_window is not None:
                    sliding_window_blocks = (self.sliding_window //
                                             self.block_size)
                    block_table = block_table[-sliding_window_blocks:]
                block_tables.append(block_table)

        max_decode_seq_len = max(seq_lens)

        input_tokens = torch.tensor(input_tokens,
                                    dtype=torch.long,
                                    device=self.device)
        input_positions = torch.tensor(input_positions,
                                       dtype=torch.long,
                                       device=self.device)
        slot_mapping = torch.tensor(slot_mapping,
                                    dtype=torch.long,
                                    device=self.device)
        seq_lens_tensor = torch.tensor(seq_lens,
                                       dtype=torch.int,
                                       device=self.device)

        block_tables = make_tensor_with_pad(
            block_tables,
            pad=0,
            dtype=torch.int,
            device=self.device,
        )

        attn_metadata = self.attn_backend.make_metadata(
            is_prompt=False,
            slot_mapping=slot_mapping,
            multi_modal_placeholder_index_maps=None,
            enable_kv_scales_calculation=False,
            seq_lens=seq_lens,
            seqlen_q=torch.tensor([]),
            max_seqlen=0,
            seq_lens_tensor=seq_lens_tensor,
            max_decode_seq_len=max_decode_seq_len,
            num_prefill_tokens=0,
            num_decode_tokens=len(input_tokens),
            num_prefills=0,
            block_tables=block_tables,
        )
        return (
            input_tokens,
            input_positions,
            attn_metadata,
        )


class XPUModelRunner(ModelRunnerBase[ModelInputForXPUWithSamplingMetadata]):
    _model_input_cls: Type[ModelInputForXPUWithSamplingMetadata] = (
        ModelInputForXPUWithSamplingMetadata)
    _builder_cls: Type[ModelInputForXPUBuilder] = ModelInputForXPUBuilder

    def __init__(
        self,
        vllm_config: VllmConfig,
        kv_cache_dtype: Optional[str] = "auto",
        is_driver_worker: bool = False,
        return_hidden_states: bool = False,
        input_registry: InputRegistry = INPUT_REGISTRY,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
    ):

        ModelRunnerBase.__init__(self, vllm_config=vllm_config)
        model_config = self.model_config
        cache_config = self.cache_config
        self.is_driver_worker = is_driver_worker
        self.return_hidden_states = return_hidden_states

        self.device = self.device_config.device

        self.kv_cache_dtype = kv_cache_dtype
        self.sliding_window = model_config.get_sliding_window()
        self.block_size = cache_config.block_size

        self.attn_backend = get_attn_backend(
            self.model_config.get_head_size(),
            self.model_config.dtype,
            self.kv_cache_dtype,
            self.block_size,
            self.model_config.is_attention_free,
        )

        # Multi-modal data support
        self.input_registry = input_registry
        self.mm_registry = mm_registry

        # Lazy initialization.
        self.model: nn.Module  # Set after load_model.
        self.sampler = get_sampler()

        self.sampling_metadata_cache: SamplingMetadataCache = \
            SamplingMetadataCache() \
            if self.parallel_config.pipeline_parallel_size == 1 else None

        self.builder = self._builder_cls(weakref.proxy(self))

    def load_model(self) -> None:
        with DeviceMemoryProfiler() as m:
            self.model = get_model(vllm_config=self.vllm_config)

        self.model_memory_usage = m.consumed_memory
        logger.info("Loading model weights took %.4f GiB",
                    self.model_memory_usage / GiB_bytes)

    def get_model(self) -> nn.Module:
        return self.model

    @property
    def vocab_size(self) -> int:
        return self.model_config.get_vocab_size()

    @torch.inference_mode()
    def profile_run(self) -> None:
        # Enable top-k sampling to reflect the accurate memory usage.
        sampling_params = SamplingParams(top_p=0.99, top_k=self.vocab_size - 1)
        max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens
        max_num_seqs = self.scheduler_config.max_num_seqs

        # Profile memory usage with max_num_seqs sequences and the total
        # number of tokens equal to max_num_batched_tokens.
        seqs: List[SequenceGroupMetadata] = []
        # Additional GPU memory may be needed for multi-modal encoding, which
        # needs to be accounted for when calculating the GPU blocks for
        # the vLLM block manager.
        # To exercise the worst scenario for GPU memory consumption,
        # the number of seqs (batch_size) is chosen to maximize the number
        # of images processed.
        max_mm_tokens = self.mm_registry.get_max_multimodal_tokens(
            self.model_config)
        if max_mm_tokens > 0:
            max_num_seqs_orig = max_num_seqs
            max_num_seqs = min(max_num_seqs,
                               max_num_batched_tokens // max_mm_tokens)
            if max_num_seqs < 1:
                expr = (f"min({max_num_seqs_orig}, "
                        f"{max_num_batched_tokens} // {max_mm_tokens})")
                logger.warning(
                    "Computed max_num_seqs (%s) to be less than 1. "
                    "Setting it to the minimum value of 1.", expr)
                max_num_seqs = 1

        batch_size = 0
        for group_id in range(max_num_seqs):
            seq_len = (max_num_batched_tokens // max_num_seqs +
                       (group_id < max_num_batched_tokens % max_num_seqs))
            batch_size += seq_len

            dummy_data = self.input_registry \
                .dummy_data_for_profiling(self.model_config,
                                          seq_len,
                                          self.mm_registry)

            seq = SequenceGroupMetadata(
                request_id=str(group_id),
                is_prompt=True,
                seq_data={group_id: dummy_data.seq_data},
                sampling_params=sampling_params,
                block_tables=None,
                lora_request=None,
                multi_modal_data=dummy_data.multi_modal_data,
                multi_modal_placeholders=dummy_data.multi_modal_placeholders)
            seqs.append(seq)

        finished_requests_ids = [seq.request_id for seq in seqs]
        model_input = self.prepare_model_input(
            seqs, finished_requests_ids=finished_requests_ids)
        intermediate_tensors = None
        if not get_pp_group().is_first_rank:
            intermediate_tensors = self.model.make_empty_intermediate_tensors(
                batch_size=batch_size,
                dtype=self.model_config.dtype,
                device=self.device)
        self.execute_model(model_input, None, intermediate_tensors)
        torch.xpu.synchronize()
        return

    def make_model_input_from_broadcasted_tensor_dict(
            self,
            tensor_dict: Dict[str,
                              Any]) -> ModelInputForXPUWithSamplingMetadata:
        return (
            ModelInputForXPUWithSamplingMetadata.from_broadcasted_tensor_dict(
                tensor_dict,
                attn_backend=self.attn_backend,
            ))

    def _prepare_model_input_tensors(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        finished_requests_ids: Optional[List[str]] = None
    ) -> ModelInputForXPUWithSamplingMetadata:
        """Helper method to prepare the model input based on a given sequence
        group. Prepares metadata needed for the base model forward pass but not
        metadata for possible additional steps, e.g., sampling.
        """
        builder = self.builder
        builder.prepare(finished_requests_ids)
        for seq_group_metadata in seq_group_metadata_list:
            builder.add_seq_group(seq_group_metadata)

        return builder.build()  # type: ignore

    def prepare_model_input(
        self,
        seq_group_metadata_list: List[SequenceGroupMetadata],
        virtual_engine: int = 0,
        finished_requests_ids: Optional[List[str]] = None
    ) -> ModelInputForXPUWithSamplingMetadata:
        """Prepare the model input based on a given sequence group, including
        metadata for the sampling step.
        """
        model_input = self._prepare_model_input_tensors(
            seq_group_metadata_list, finished_requests_ids)
        # Sampling metadata is only required for the final pp group
        generators = self.get_generators(finished_requests_ids)
        sampling_metadata = SamplingMetadata.prepare(
            seq_group_metadata_list,
            model_input.seq_lens,
            model_input.query_lens,
            self.device,
            pin_memory=False,
            generators=generators,
            cache=self.sampling_metadata_cache)

        return dataclasses.replace(model_input,
                                   sampling_metadata=sampling_metadata,
                                   virtual_engine=virtual_engine)

    @torch.inference_mode()
    def execute_model(
        self,
        model_input: ModelInputForXPUWithSamplingMetadata,
        kv_caches: List[torch.Tensor],
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
    ) -> Optional[List[SamplerOutput]]:
        if num_steps > 1:
            raise ValueError(
                "XPUModelRunner does not support multi-step execution.")

        model_executable = self.model
        if (self.observability_config is not None
                and self.observability_config.collect_model_forward_time):
            model_forward_start_time = time.time()
        with set_forward_context(model_input.attn_metadata, self.vllm_config,
                                 model_input.virtual_engine):
            hidden_or_intermediate_states = model_executable(
                input_ids=model_input.input_tokens,
                positions=model_input.input_positions,
                intermediate_tensors=intermediate_tensors,
                **MultiModalKwargs.as_kwargs(
                    model_input.multi_modal_kwargs or {},
                    device=self.device,
                ),
            )
        # Compute the logits in the last pipeline stage.
        if not get_pp_group().is_last_rank:
            return hidden_or_intermediate_states

        if (self.observability_config is not None
                and self.observability_config.collect_model_forward_time):
            model_forward_end_time = time.time()

        # Compute the logits.
        logits = self.model.compute_logits(hidden_or_intermediate_states,
                                           model_input.sampling_metadata)

        # Only perform sampling in the driver worker.
        if not self.is_driver_worker:
            return []

        if model_input.async_callback is not None:
            model_input.async_callback()

        # Sample the next token.
        output: SamplerOutput = self.sampler(
            logits=logits,
            sampling_metadata=model_input.sampling_metadata,
        )
        if (self.observability_config is not None
                and self.observability_config.collect_model_forward_time
                and output is not None):
            model_forward_time = (model_forward_end_time -
                                  model_forward_start_time)
            # If there are multiple workers, we are still tracking the latency
            # from the start time of the driver worker to the end time of the
            # driver worker. The model forward time will then end up covering
            # the communication time as well.
            output.model_forward_time = model_forward_time

        return [output]
@ -0,0 +1,186 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""An XPU worker class."""
import gc
import os
from typing import List, Optional, Tuple

import intel_extension_for_pytorch  # noqa: F401
import oneccl_bindings_for_pytorch  # noqa: F401
import torch
import torch.distributed

from vllm.config import VllmConfig
from vllm.distributed import (ensure_model_parallel_initialized,
                              init_distributed_environment)
from vllm.distributed.parallel_state import get_pp_group
from vllm.logger import init_logger
from vllm.model_executor import set_random_seed
from vllm.platforms import current_platform
from vllm.worker.cache_engine import CacheEngine
from vllm.worker.worker import Worker
from vllm.worker.worker_base import LoRANotSupportedWorkerBase, WorkerBase
from vllm.worker.xpu_model_runner import XPUModelRunner

logger = init_logger(__name__)


class XPUWorker(LoRANotSupportedWorkerBase, Worker):
    """A worker class that executes (a partition of) the model on an XPU.

    Each worker is associated with a single XPU device. The worker is
    responsible for maintaining the KV cache and executing the model on the
    XPU. In case of distributed inference, each worker is assigned a partition
    of the model.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_rank: int,
        rank: int,
        distributed_init_method: str,
        is_driver_worker: bool = False,
    ) -> None:
        WorkerBase.__init__(self, vllm_config=vllm_config)
        device_config = self.device_config
        parallel_config = self.parallel_config
        assert device_config.device_type == "xpu"
        assert current_platform.is_xpu()

        self.parallel_config.rank = rank

        self.local_rank = local_rank
        self.rank = rank
        self.distributed_init_method = distributed_init_method
        self.is_driver_worker = is_driver_worker
        if parallel_config and is_driver_worker:
            assert rank % parallel_config.tensor_parallel_size == 0, \
                "Driver worker should be rank 0 of tensor parallel group."

        self.model_runner = XPUModelRunner(  # type: ignore
            vllm_config=vllm_config,
            kv_cache_dtype=self.cache_config.cache_dtype,
            is_driver_worker=is_driver_worker,
        )
        # Uninitialized cache engine. Will be initialized by
        # initialize_cache.
        self.cache_engine: List[CacheEngine]
        self.gpu_cache: Optional[List[List[torch.Tensor]]]

    def init_device(self) -> None:
        if self.device_config.device.type == "xpu" and current_platform.is_xpu(
        ):
            self.device = torch.device(f"xpu:{self.local_rank}")
            torch.xpu.set_device(self.device)
            torch.xpu.empty_cache()
            self.init_gpu_memory = torch.xpu.get_device_properties(
                self.local_rank).total_memory
        else:
            raise RuntimeError(
                f"Unsupported device type: {self.device_config.device}")
        # Initialize the distributed environment.
        self.init_worker_distributed_environment()
        # Initialize the model.
        set_random_seed(self.model_config.seed)

    # keep this method for `empty_cache` and `synchronize` api
    @torch.inference_mode()
    def determine_num_available_blocks(self) -> Tuple[int, int]:
        """Profiles the peak memory usage of the model to determine how many
        KV blocks may be allocated without OOMs.

        The engine will first conduct a profiling of the existing memory usage.
        Then, it calculates the maximum possible number of GPU and CPU blocks
        that can be allocated with the remaining free memory.

        Tip:
            You may limit the usage of GPU memory
            by adjusting the `gpu_memory_utilization` parameter.
        """
        # Profile the memory usage of the model and get the maximum number of
        # cache blocks that can be allocated with the remaining free memory.
        torch.xpu.empty_cache()

        # Execute a forward pass with dummy inputs to profile the memory usage
        # of the model.
        self.model_runner.profile_run()

        # Calculate the number of blocks that can be allocated with the
        # profiled peak memory.
        torch.xpu.synchronize()
        used_memory = torch.xpu.memory_allocated()
        total_gpu_memory = torch.xpu.get_device_properties(
            self.local_rank).total_memory
        free_gpu_memory = total_gpu_memory - used_memory

        # NOTE(woosuk): Here we assume that the other processes using the same
        # GPU did not change their memory usage during the profiling.
        peak_memory = self.init_gpu_memory - free_gpu_memory
        assert peak_memory > 0, (
            "Error in memory profiling. "
            f"Initial free memory {self.init_gpu_memory}, current free memory"
            f" {free_gpu_memory}. This happens when the GPU memory was "
            "not properly cleaned up before initializing the vLLM instance.")

        cache_block_size = self.get_cache_block_size_bytes()
        num_gpu_blocks = int(
            (total_gpu_memory * self.cache_config.gpu_memory_utilization -
             peak_memory) // cache_block_size)
        num_cpu_blocks = int(self.cache_config.swap_space_bytes //
                             cache_block_size)
        num_gpu_blocks = max(num_gpu_blocks, 0)
        num_cpu_blocks = max(num_cpu_blocks, 0)
        gc.collect()
        torch.xpu.empty_cache()
        return num_gpu_blocks, num_cpu_blocks

    def _warm_up_model(self) -> None:
        # IPEX doesn't support graph capture yet.
        pass

    def init_worker_distributed_environment(self) -> None:
        """Initialize the distributed environment."""

        parallel_config = self.parallel_config
        rank = self.rank
        distributed_init_method = self.distributed_init_method

        if torch.distributed.is_initialized():
            torch_world_size = torch.distributed.get_world_size()
            if torch_world_size != parallel_config.world_size:
                raise RuntimeError(
                    "torch.distributed is already initialized but the torch "
                    "world size does not match parallel_config.world_size "
                    f"({torch_world_size} vs. {parallel_config.world_size}).")
        elif not distributed_init_method:
            raise ValueError(
                "distributed_init_method must be set if torch.distributed "
                "is not already initialized")
        else:
            # Use sockets as the default Level Zero IPC exchange backend. By
            # default, oneccl uses `drmfd` as the mechanism, which needs extra
            # dependencies (libdrm and drm headers) on your system.
            ENV_CCL_ATL_TRANSPORT = os.getenv("CCL_ATL_TRANSPORT", "ofi")
            ENV_LOCAL_WORLD_SIZE = os.getenv("LOCAL_WORLD_SIZE",
                                             str(parallel_config.world_size))
            os.environ["CCL_ATL_TRANSPORT"] = ENV_CCL_ATL_TRANSPORT
            os.environ["LOCAL_WORLD_SIZE"] = ENV_LOCAL_WORLD_SIZE
            os.environ["LOCAL_RANK"] = str(self.local_rank)
            init_distributed_environment(
                world_size=parallel_config.world_size,
                rank=rank,
                distributed_init_method=distributed_init_method,
                local_rank=self.local_rank,
                backend="ccl")

        ensure_model_parallel_initialized(
            parallel_config.tensor_parallel_size,
            parallel_config.pipeline_parallel_size)
        # A global all_reduce is needed for the overall oneccl warm up.
        torch.distributed.all_reduce(torch.zeros(1).xpu())

        if parallel_config.pipeline_parallel_size > 1:
            # Add pp group init to avoid
            # p2p communication as the first call
            get_pp_group().all_reduce(torch.zeros(1).xpu())