[Frontend] Refactor prompt processing (#4028)

Co-authored-by: Roger Wang <ywang@roblox.com>
Cyrus Leung 2024-07-23 01:13:53 +08:00 committed by GitHub
parent 89c1c6a196
commit 739b61a348
24 changed files with 699 additions and 391 deletions

View File

@ -11,7 +11,7 @@ from tqdm import tqdm
from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import EngineArgs
from vllm.inputs import PromptStrictInputs
from vllm.inputs import PromptInputs
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.utils import FlexibleArgumentParser
@ -61,7 +61,7 @@ def main(args: argparse.Namespace):
dummy_prompt_token_ids = np.random.randint(10000,
size=(args.batch_size,
args.input_len))
dummy_inputs: List[PromptStrictInputs] = [{
dummy_inputs: List[PromptInputs] = [{
"prompt_token_ids": batch
} for batch in dummy_prompt_token_ids.tolist()]
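
For reference, the renamed `PromptInputs` covers the same forms the benchmark already passes: a plain string, a `TextPrompt` dict, or a `TokensPrompt` dict. A minimal sketch, assuming a small illustrative model:

```python
# Minimal sketch of the PromptInputs forms (model name is illustrative).
from typing import List

from vllm import LLM, SamplingParams
from vllm.inputs import PromptInputs

llm = LLM(model="facebook/opt-125m")  # assumed small model for illustration

inputs: List[PromptInputs] = [
    "Hello, my name is",                     # plain string
    {"prompt": "The capital of France is"},  # TextPrompt
    {"prompt_token_ids": [1, 2, 3, 4]},      # TokensPrompt, as in the benchmark above
]

outputs = llm.generate(inputs, sampling_params=SamplingParams(max_tokens=16))
for output in outputs:
    print(output.outputs[0].text)
```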

View File

@ -8,7 +8,7 @@ Multi-Modality
vLLM provides experimental support for multi-modal models through the :mod:`vllm.multimodal` package.
Multi-modal inputs can be passed alongside text and token prompts to :ref:`supported models <supported_vlms>`
via the ``multi_modal_data`` field in :class:`vllm.inputs.PromptStrictInputs`.
via the ``multi_modal_data`` field in :class:`vllm.inputs.PromptInputs`.
Currently, vLLM only has built-in support for image data. You can extend vLLM to process additional modalities
by following :ref:`this guide <adding_multimodal_plugin>`.

View File

@ -1,7 +1,7 @@
LLM Inputs
==========
.. autodata:: vllm.inputs.PromptStrictInputs
.. autodata:: vllm.inputs.PromptInputs
.. autoclass:: vllm.inputs.TextPrompt
:show-inheritance:

View File

@ -30,7 +30,7 @@ To initialize a VLM, the aforementioned arguments must be passed to the ``LLM``
internally for each model.
To pass an image to the model, note the following in :class:`vllm.inputs.PromptStrictInputs`:
To pass an image to the model, note the following in :class:`vllm.inputs.PromptInputs`:
* ``prompt``: The prompt should follow the format that is documented on HuggingFace.
* ``multi_modal_data``: This is a dictionary that follows the schema defined in :class:`vllm.multimodal.MultiModalDataDict`.
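
Putting the two fields together, a hedged sketch for an image prompt; the model, its prompt format, and the ``image`` key are assumptions here, so check the model card on HuggingFace:

```python
# Sketch of passing an image alongside text (model and prompt format assumed).
from PIL import Image

from vllm import LLM

llm = LLM(model="llava-hf/llava-1.5-7b-hf")  # assumed vision-language model
image = Image.open("example.jpg")            # assumed local image file

outputs = llm.generate({
    "prompt": "USER: <image>\nWhat is in this picture? ASSISTANT:",
    "multi_modal_data": {"image": image},
})
print(outputs[0].outputs[0].text)
```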

View File

@ -35,8 +35,8 @@ def sequence_with_eos(text: str, eos_token: str,
@pytest.mark.parametrize(["text_wo_eos", "eos_token", "eos_token_id"], [
("This text ends with EOS token", "</s>", 2),
])
@pytest.mark.parametrize("ignore_eos", [True, False, None])
@pytest.mark.parametrize("include_stop_str_in_output", [True, False, None])
@pytest.mark.parametrize("ignore_eos", [True, False])
@pytest.mark.parametrize("include_stop_str_in_output", [True, False])
@pytest.mark.skip_global_cleanup
def test_stop_on_eos_token(text_wo_eos: str, eos_token: str, eos_token_id: int,
ignore_eos: bool, include_stop_str_in_output: bool):

View File

@ -32,7 +32,10 @@ async def _async_serving_chat_init():
model_config,
served_model_names=[MODEL_NAME],
response_role="assistant",
chat_template=CHAT_TEMPLATE)
chat_template=CHAT_TEMPLATE,
lora_modules=None,
prompt_adapters=None,
request_logger=None)
return serving_completion

View File

@ -5,7 +5,7 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.llm_engine import LLMEngine
from vllm.entrypoints.llm import LLM
from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.inputs import PromptStrictInputs, TextPrompt, TokensPrompt
from vllm.inputs import PromptInputs, TextPrompt, TokensPrompt
from vllm.model_executor.models import ModelRegistry
from vllm.outputs import (CompletionOutput, EmbeddingOutput,
EmbeddingRequestOutput, RequestOutput)
@ -19,7 +19,7 @@ __all__ = [
"__version__",
"LLM",
"ModelRegistry",
"PromptStrictInputs",
"PromptInputs",
"TextPrompt",
"TokensPrompt",
"SamplingParams",

View File

@ -827,7 +827,6 @@ class AsyncEngineArgs(EngineArgs):
"""Arguments for asynchronous vLLM engine."""
engine_use_ray: bool = False
disable_log_requests: bool = False
max_log_len: Optional[int] = None
@staticmethod
def add_cli_args(parser: FlexibleArgumentParser,
@ -841,12 +840,6 @@ class AsyncEngineArgs(EngineArgs):
parser.add_argument('--disable-log-requests',
action='store_true',
help='Disable logging requests.')
parser.add_argument('--max-log-len',
type=int,
default=None,
help='Max number of prompt characters or prompt '
'ID numbers being printed in log.'
'\n\nDefault: Unlimited')
return parser

View File

@ -1,8 +1,8 @@
import asyncio
import time
from functools import partial
from typing import (AsyncIterator, Callable, Dict, Iterable, List, Optional,
Set, Tuple, Type, Union)
from typing import (AsyncIterator, Callable, Dict, Iterable, List, Mapping,
Optional, Set, Tuple, Type, Union)
from transformers import PreTrainedTokenizer
@ -151,7 +151,10 @@ class RequestTracker:
logger.info("Finished request %s.", request_id)
self.abort_request(request_id)
def add_request(self, request_id: str,
def add_request(self,
request_id: str,
*,
verbose: bool = False,
**engine_add_request_kwargs) -> AsyncStream:
"""Add a request to be sent to the engine on the next background
loop iteration."""
@ -166,6 +169,9 @@ class RequestTracker:
self.new_requests_event.set()
if verbose:
logger.info("Added request %s.", request_id)
return stream
def abort_request(self, request_id: str, *, verbose: bool = False) -> None:
@ -299,14 +305,14 @@ class _AsyncLLMEngine(LLMEngine):
return self.input_processor(llm_inputs)
async def add_request_async(
self,
request_id: str,
inputs: PromptInputs,
params: Union[SamplingParams, PoolingParams],
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Dict[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None
self,
request_id: str,
inputs: PromptInputs,
params: Union[SamplingParams, PoolingParams],
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> None:
if lora_request is not None and not self.lora_config:
raise ValueError(f"Got lora_request {lora_request} but LoRA is "
@ -353,8 +359,6 @@ class AsyncLLMEngine:
async frontend will be executed in a separate process as the
model workers.
log_requests: Whether to log the requests.
max_log_len: Maximum number of prompt characters or prompt ID numbers
being printed in log.
start_engine_loop: If True, the background task to run the engine
will be automatically started in the generate call.
*args: Arguments for :class:`LLMEngine`.
@ -368,13 +372,11 @@ class AsyncLLMEngine:
engine_use_ray: bool,
*args,
log_requests: bool = True,
max_log_len: Optional[int] = None,
start_engine_loop: bool = True,
**kwargs) -> None:
self.worker_use_ray = worker_use_ray
self.engine_use_ray = engine_use_ray
self.log_requests = log_requests
self.max_log_len = max_log_len
self.engine = self._init_engine(*args, **kwargs)
self.background_loop: Optional[asyncio.Future] = None
@ -468,7 +470,6 @@ class AsyncLLMEngine:
executor_class=executor_class,
log_requests=not engine_args.disable_log_requests,
log_stats=not engine_args.disable_log_stats,
max_log_len=engine_args.max_log_len,
start_engine_loop=start_engine_loop,
usage_context=usage_context,
stat_loggers=stat_loggers,
@ -667,30 +668,9 @@ class AsyncLLMEngine:
params: Union[SamplingParams, PoolingParams],
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Dict[str, str]] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None
) -> AsyncStream:
if self.log_requests:
if isinstance(inputs, str):
shortened_prompt = inputs
shortened_token_ids = None
else:
shortened_prompt = inputs.get("prompt")
shortened_token_ids = inputs.get("prompt_token_ids")
max_log_len = self.max_log_len
if max_log_len is not None:
if shortened_prompt is not None:
shortened_prompt = shortened_prompt[:max_log_len]
if shortened_token_ids is not None:
shortened_token_ids = shortened_token_ids[:max_log_len]
logger.info(
"Received request %s: prompt: %r, "
"params: %s, prompt_token_ids: %s, "
"lora_request: %s.", request_id, shortened_prompt, params,
shortened_token_ids, lora_request)
if not self.is_running:
if self.start_engine_loop:
self.start_background_loop()
@ -706,6 +686,7 @@ class AsyncLLMEngine:
stream = self._request_tracker.add_request(
request_id,
verbose=self.log_requests,
inputs=inputs,
params=params,
arrival_time=arrival_time,
@ -721,7 +702,7 @@ class AsyncLLMEngine:
sampling_params: SamplingParams,
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Dict[str, str]] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None
) -> AsyncIterator[RequestOutput]:
"""Generate outputs for a request.
@ -804,7 +785,7 @@ class AsyncLLMEngine:
pooling_params: PoolingParams,
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Dict[str, str]] = None,
trace_headers: Optional[Mapping[str, str]] = None,
) -> AsyncIterator[EmbeddingRequestOutput]:
"""Generate outputs for a request from an embedding model.
@ -882,7 +863,7 @@ class AsyncLLMEngine:
params: Union[SamplingParams, PoolingParams],
*,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Dict[str, str]] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> AsyncIterator[Union[RequestOutput, EmbeddingRequestOutput]]:
"""Common logic to process requests with SamplingParams or

View File

@ -1,6 +1,7 @@
import time
from contextlib import contextmanager
from typing import TYPE_CHECKING, Any, ClassVar, Dict, Iterable, List, Optional
from typing import (TYPE_CHECKING, Any, ClassVar, Dict, Iterable, List,
Mapping, Optional)
from typing import Sequence as GenericSequence
from typing import Set, Type, TypeVar, Union
@ -522,7 +523,7 @@ class LLMEngine:
arrival_time: float,
lora_request: Optional[LoRARequest],
prompt_adapter_request: Optional[PromptAdapterRequest],
trace_headers: Optional[Dict[str, str]] = None,
trace_headers: Optional[Mapping[str, str]] = None,
) -> None:
# Create the sequences.
block_size = self.cache_config.block_size
@ -603,7 +604,7 @@ class LLMEngine:
params: Union[SamplingParams, PoolingParams],
arrival_time: Optional[float] = None,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Dict[str, str]] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> None:
"""Add a request to the engine's request pool.
@ -677,7 +678,7 @@ class LLMEngine:
sampling_params: SamplingParams,
arrival_time: float,
lora_request: Optional[LoRARequest],
trace_headers: Optional[Dict[str, str]] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> SequenceGroup:
"""Creates a SequenceGroup with SamplingParams."""

View File

@ -6,8 +6,7 @@ from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine
from vllm.inputs import (PromptInputs, PromptStrictInputs, TextPrompt,
TextTokensPrompt, TokensPrompt,
from vllm.inputs import (PromptInputs, TextPrompt, TokensPrompt,
parse_and_batch_prompt)
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
@ -238,7 +237,7 @@ class LLM:
@overload
def generate(
self,
inputs: Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
inputs: Union[PromptInputs, Sequence[PromptInputs]],
/, # We may enable `inputs` keyword after removing the old API
*,
sampling_params: Optional[Union[SamplingParams,
@ -255,7 +254,7 @@ class LLM:
"instead.")
def generate(
self,
prompts: Union[Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
prompts: Union[Union[PromptInputs, Sequence[PromptInputs]],
Optional[Union[str, List[str]]]] = None,
sampling_params: Optional[Union[SamplingParams,
Sequence[SamplingParams]]] = None,
@ -302,9 +301,7 @@ class LLM:
prompt_token_ids=prompt_token_ids,
)
else:
inputs = cast(
Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
prompts)
inputs = cast(Union[PromptInputs, Sequence[PromptInputs]], prompts)
if sampling_params is None:
# Use default sampling params.
@ -383,7 +380,7 @@ class LLM:
@overload
def encode(
self,
inputs: Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
inputs: Union[PromptInputs, Sequence[PromptInputs]],
/, # We may enable `inputs` keyword after removing the old API
*,
pooling_params: Optional[Union[PoolingParams,
@ -400,7 +397,7 @@ class LLM:
"instead.")
def encode(
self,
prompts: Union[Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
prompts: Union[Union[PromptInputs, Sequence[PromptInputs]],
Optional[Union[str, List[str]]]] = None,
pooling_params: Optional[Union[PoolingParams,
Sequence[PoolingParams]]] = None,
@ -417,7 +414,7 @@ class LLM:
Args:
inputs: The inputs to the LLM. You may pass a sequence of inputs for
batch inference. See :class:`~vllm.inputs.PromptStrictInputs`
batch inference. See :class:`~vllm.inputs.PromptInputs`
for more details about the format of each input.
pooling_params: The pooling parameters for pooling. If None, we
use the default pooling parameters.
@ -446,9 +443,7 @@ class LLM:
prompt_token_ids=prompt_token_ids,
)
else:
inputs = cast(
Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
prompts)
inputs = cast(Union[PromptInputs, Sequence[PromptInputs]], prompts)
if pooling_params is None:
# Use default pooling params.
@ -496,17 +491,11 @@ class LLM:
inputs: List[PromptInputs] = []
for i in range(num_requests):
if prompts is not None:
if prompt_token_ids is not None:
item = TextTokensPrompt(
prompt=prompts[i],
prompt_token_ids=prompt_token_ids[i])
else:
item = TextPrompt(prompt=prompts[i])
item = TextPrompt(prompt=prompts[i])
elif prompt_token_ids is not None:
item = TokensPrompt(prompt_token_ids=prompt_token_ids[i])
else:
if prompt_token_ids is not None:
item = TokensPrompt(prompt_token_ids=prompt_token_ids[i])
else:
raise AssertionError
raise AssertionError
inputs.append(item)
@ -514,7 +503,7 @@ class LLM:
def _validate_and_add_requests(
self,
inputs: Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
inputs: Union[PromptInputs, Sequence[PromptInputs]],
params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
Sequence[PoolingParams]],
lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],

View File

@ -0,0 +1,41 @@
from typing import List, Optional, Union
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
logger = init_logger(__name__)
class RequestLogger:
def __init__(self, *, max_log_len: Optional[int]) -> None:
super().__init__()
self.max_log_len = max_log_len
def log_inputs(
self,
request_id: str,
prompt: Optional[str],
prompt_token_ids: Optional[List[int]],
params: Optional[Union[SamplingParams, PoolingParams]],
lora_request: Optional[LoRARequest],
prompt_adapter_request: Optional[PromptAdapterRequest],
) -> None:
max_log_len = self.max_log_len
if max_log_len is not None:
if prompt is not None:
prompt = prompt[:max_log_len]
if prompt_token_ids is not None:
prompt_token_ids = prompt_token_ids[:max_log_len]
logger.info(
"Received request %s: prompt: %r, "
"params: %s, prompt_token_ids: %s, "
"lora_request: %s, prompt_adapter_request: %s.", request_id,
prompt, params, prompt_token_ids, lora_request,
prompt_adapter_request)
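
The truncation-and-log behaviour that used to live inside `AsyncLLMEngine` is thus reusable by every entrypoint. A minimal usage sketch (values are illustrative):

```python
# Minimal usage sketch of RequestLogger (illustrative values).
from vllm.entrypoints.logger import RequestLogger
from vllm.sampling_params import SamplingParams

# Truncate logged prompt text / token IDs to at most 64 items.
request_logger = RequestLogger(max_log_len=64)

request_logger.log_inputs(
    request_id="cmpl-1234",
    prompt="Tell me a story about a robot that learns to paint.",
    prompt_token_ids=None,
    params=SamplingParams(max_tokens=128),
    lora_request=None,
    prompt_adapter_request=None,
)
```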

View File

@ -18,6 +18,7 @@ from starlette.routing import Mount
import vllm.envs as envs
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.cli_args import make_arg_parser
# yapf conflicts with isort for this block
# yapf: disable
@ -244,24 +245,48 @@ def run_server(args, llm_engine=None):
# When using single vLLM without engine_use_ray
model_config = asyncio.run(engine.get_model_config())
if args.disable_log_requests:
request_logger = None
else:
request_logger = RequestLogger(max_log_len=args.max_log_len)
global openai_serving_chat
global openai_serving_completion
global openai_serving_embedding
global openai_serving_tokenization
openai_serving_chat = OpenAIServingChat(engine, model_config,
served_model_names,
args.response_role,
args.lora_modules,
args.chat_template)
openai_serving_chat = OpenAIServingChat(
engine,
model_config,
served_model_names,
args.response_role,
lora_modules=args.lora_modules,
prompt_adapters=args.prompt_adapters,
request_logger=request_logger,
chat_template=args.chat_template,
)
openai_serving_completion = OpenAIServingCompletion(
engine, model_config, served_model_names, args.lora_modules,
args.prompt_adapters)
openai_serving_embedding = OpenAIServingEmbedding(engine, model_config,
served_model_names)
engine,
model_config,
served_model_names,
lora_modules=args.lora_modules,
prompt_adapters=args.prompt_adapters,
request_logger=request_logger,
)
openai_serving_embedding = OpenAIServingEmbedding(
engine,
model_config,
served_model_names,
request_logger=request_logger,
)
openai_serving_tokenization = OpenAIServingTokenization(
engine, model_config, served_model_names, args.lora_modules,
args.chat_template)
engine,
model_config,
served_model_names,
lora_modules=args.lora_modules,
request_logger=request_logger,
chat_template=args.chat_template,
)
app.root_path = args.root_path
logger.info("Available routes are:")

View File

@ -130,6 +130,14 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
"using app.add_middleware(). ")
parser = AsyncEngineArgs.add_cli_args(parser)
parser.add_argument('--max-log-len',
type=int,
default=None,
help='Max number of prompt characters or prompt '
'ID numbers being printed in log.'
'\n\nDefault: Unlimited')
return parser
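
With the flag now attached to the server parser rather than `AsyncEngineArgs`, it is supplied when launching the OpenAI-compatible server, for example (model name illustrative):

```bash
python -m vllm.entrypoints.openai.api_server \
    --model facebook/opt-125m \
    --max-log-len 100
```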

View File

@ -121,40 +121,42 @@ class ChatCompletionRequest(OpenAIBaseModel):
# doc: begin-chat-completion-sampling-params
best_of: Optional[int] = None
use_beam_search: Optional[bool] = False
top_k: Optional[int] = -1
min_p: Optional[float] = 0.0
repetition_penalty: Optional[float] = 1.0
length_penalty: Optional[float] = 1.0
early_stopping: Optional[bool] = False
ignore_eos: Optional[bool] = False
min_tokens: Optional[int] = 0
use_beam_search: bool = False
top_k: int = -1
min_p: float = 0.0
repetition_penalty: float = 1.0
length_penalty: float = 1.0
early_stopping: bool = False
stop_token_ids: Optional[List[int]] = Field(default_factory=list)
skip_special_tokens: Optional[bool] = True
spaces_between_special_tokens: Optional[bool] = True
include_stop_str_in_output: bool = False
ignore_eos: bool = False
min_tokens: int = 0
skip_special_tokens: bool = True
spaces_between_special_tokens: bool = True
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
# doc: end-chat-completion-sampling-params
# doc: begin-chat-completion-extra-params
echo: Optional[bool] = Field(
echo: bool = Field(
default=False,
description=(
"If true, the new message will be prepended with the last message "
"if they belong to the same role."),
)
add_generation_prompt: Optional[bool] = Field(
add_generation_prompt: bool = Field(
default=True,
description=
("If true, the generation prompt will be added to the chat template. "
"This is a parameter used by chat template in tokenizer config of the "
"model."),
)
add_special_tokens: Optional[bool] = Field(
add_special_tokens: bool = Field(
default=False,
description=(
"If true, special tokens (e.g. BOS) will be added to the prompt "
"on top of what is added by the chat template. "
"For most models, the chat template takes care of adding the "
"special tokens so this should be set to False (as is the "
"special tokens so this should be set to false (as is the "
"default)."),
)
documents: Optional[List[Dict[str, str]]] = Field(
@ -178,12 +180,6 @@ class ChatCompletionRequest(OpenAIBaseModel):
description=("Additional kwargs to pass to the template renderer. "
"Will be accessible by the chat template."),
)
include_stop_str_in_output: Optional[bool] = Field(
default=False,
description=(
"Whether to include the stop string in the output. "
"This is only applied when the stop or stop_token_ids is set."),
)
guided_json: Optional[Union[str, dict, BaseModel]] = Field(
default=None,
description=("If specified, the output will follow the JSON schema."),
@ -244,22 +240,22 @@ class ChatCompletionRequest(OpenAIBaseModel):
return SamplingParams(
n=self.n,
best_of=self.best_of,
presence_penalty=self.presence_penalty,
frequency_penalty=self.frequency_penalty,
repetition_penalty=self.repetition_penalty,
temperature=self.temperature,
top_p=self.top_p,
top_k=self.top_k,
min_p=self.min_p,
seed=self.seed,
stop=self.stop,
stop_token_ids=self.stop_token_ids,
max_tokens=self.max_tokens,
min_tokens=self.min_tokens,
logprobs=self.top_logprobs if self.logprobs else None,
prompt_logprobs=self.top_logprobs if self.echo else None,
best_of=self.best_of,
top_k=self.top_k,
ignore_eos=self.ignore_eos,
max_tokens=self.max_tokens,
min_tokens=self.min_tokens,
use_beam_search=self.use_beam_search,
early_stopping=self.early_stopping,
skip_special_tokens=self.skip_special_tokens,
@ -267,6 +263,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
include_stop_str_in_output=self.include_stop_str_in_output,
length_penalty=self.length_penalty,
logits_processors=logits_processors,
truncate_prompt_tokens=self.truncate_prompt_tokens,
)
@model_validator(mode='before')
@ -348,26 +345,27 @@ class CompletionRequest(OpenAIBaseModel):
user: Optional[str] = None
# doc: begin-completion-sampling-params
use_beam_search: Optional[bool] = False
top_k: Optional[int] = -1
min_p: Optional[float] = 0.0
repetition_penalty: Optional[float] = 1.0
length_penalty: Optional[float] = 1.0
early_stopping: Optional[bool] = False
use_beam_search: bool = False
top_k: int = -1
min_p: float = 0.0
repetition_penalty: float = 1.0
length_penalty: float = 1.0
early_stopping: bool = False
stop_token_ids: Optional[List[int]] = Field(default_factory=list)
ignore_eos: Optional[bool] = False
min_tokens: Optional[int] = 0
skip_special_tokens: Optional[bool] = True
spaces_between_special_tokens: Optional[bool] = True
include_stop_str_in_output: bool = False
ignore_eos: bool = False
min_tokens: int = 0
skip_special_tokens: bool = True
spaces_between_special_tokens: bool = True
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
# doc: end-completion-sampling-params
# doc: begin-completion-extra-params
include_stop_str_in_output: Optional[bool] = Field(
default=False,
add_special_tokens: bool = Field(
default=True,
description=(
"Whether to include the stop string in the output. "
"This is only applied when the stop or stop_token_ids is set."),
"If true (the default), special tokens (e.g. BOS) will be added to "
"the prompt."),
)
response_format: Optional[ResponseFormat] = Field(
default=None,
@ -447,15 +445,15 @@ class CompletionRequest(OpenAIBaseModel):
seed=self.seed,
stop=self.stop,
stop_token_ids=self.stop_token_ids,
logprobs=self.logprobs,
ignore_eos=self.ignore_eos,
max_tokens=self.max_tokens if not echo_without_generation else 1,
min_tokens=self.min_tokens,
logprobs=self.logprobs,
use_beam_search=self.use_beam_search,
early_stopping=self.early_stopping,
prompt_logprobs=self.logprobs if self.echo else None,
skip_special_tokens=self.skip_special_tokens,
spaces_between_special_tokens=(self.spaces_between_special_tokens),
spaces_between_special_tokens=self.spaces_between_special_tokens,
include_stop_str_in_output=self.include_stop_str_in_output,
length_penalty=self.length_penalty,
logits_processors=logits_processors,
@ -489,11 +487,11 @@ class CompletionRequest(OpenAIBaseModel):
def validate_stream_options(cls, data):
if data.get("stream_options") and not data.get("stream"):
raise ValueError(
"Stream options can only be defined when stream is True.")
"Stream options can only be defined when stream is true.")
return data
class EmbeddingRequest(BaseModel):
class EmbeddingRequest(OpenAIBaseModel):
# Ordered by official OpenAI API documentation
# https://platform.openai.com/docs/api-reference/embeddings
model: str
@ -565,13 +563,13 @@ class CompletionStreamResponse(OpenAIBaseModel):
usage: Optional[UsageInfo] = Field(default=None)
class EmbeddingResponseData(BaseModel):
class EmbeddingResponseData(OpenAIBaseModel):
index: int
object: str = "embedding"
embedding: Union[List[float], str]
class EmbeddingResponse(BaseModel):
class EmbeddingResponse(OpenAIBaseModel):
id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
object: str = "list"
created: int = Field(default_factory=lambda: int(time.time()))
@ -670,8 +668,8 @@ class BatchRequestInput(OpenAIBaseModel):
# /v1/chat/completions is supported.
url: str
# The parameteters of the request.
body: Union[ChatCompletionRequest, ]
# The parameters of the request.
body: ChatCompletionRequest
class BatchResponseData(OpenAIBaseModel):
@ -703,12 +701,22 @@ class BatchRequestOutput(OpenAIBaseModel):
error: Optional[Any]
class TokenizeRequest(OpenAIBaseModel):
class TokenizeCompletionRequest(OpenAIBaseModel):
model: str
prompt: str
add_special_tokens: bool = Field(default=True)
class TokenizeChatRequest(OpenAIBaseModel):
model: str
messages: List[ChatCompletionMessageParam]
add_generation_prompt: bool = Field(default=True)
add_special_tokens: bool = Field(default=False)
prompt: Optional[str] = Field(default=None)
messages: Optional[List[ChatCompletionMessageParam]] = Field(default=None)
model: str
TokenizeRequest = Union[TokenizeCompletionRequest, TokenizeChatRequest]
class TokenizeResponse(OpenAIBaseModel):
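
The single `TokenizeRequest` with optional `prompt`/`messages` fields is thus replaced by two explicit request shapes. Hedged example payloads matching the new models (model name illustrative):

```python
# Example payloads for the split tokenize request models (model name illustrative).

# Completion-style: tokenize a raw prompt string.
completion_payload = {
    "model": "meta-llama/Llama-2-7b-chat-hf",
    "prompt": "Hello, world!",
    "add_special_tokens": True,
}

# Chat-style: render the chat template, then tokenize the result.
chat_payload = {
    "model": "meta-llama/Llama-2-7b-chat-hf",
    "messages": [{"role": "user", "content": "Hello, world!"}],
    "add_generation_prompt": True,
    "add_special_tokens": False,
}
```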

View File

@ -6,6 +6,7 @@ import aiohttp
from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (BatchRequestInput,
BatchRequestOutput,
BatchResponseData,
@ -44,9 +45,17 @@ def parse_args():
type=nullable_str,
default="assistant",
help="The role name to return if "
"`request.add_generation_prompt=true`.")
"`request.add_generation_prompt=True`.")
parser = AsyncEngineArgs.add_cli_args(parser)
parser.add_argument('--max-log-len',
type=int,
default=None,
help='Max number of prompt characters or prompt '
'ID numbers being printed in log.'
'\n\nDefault: Unlimited')
return parser.parse_args()
@ -114,11 +123,20 @@ async def main(args):
# When using single vLLM without engine_use_ray
model_config = await engine.get_model_config()
if args.disable_log_requests:
request_logger = None
else:
request_logger = RequestLogger(max_log_len=args.max_log_len)
openai_serving_chat = OpenAIServingChat(
engine,
model_config,
served_model_names,
args.response_role,
lora_modules=None,
prompt_adapters=None,
request_logger=request_logger,
chat_template=None,
)
# Submit all requests in the file to the engine "concurrently".

View File

@ -12,6 +12,7 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.chat_utils import (ConversationMessage,
load_chat_template,
parse_chat_message_content)
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (
ChatCompletionLogProb, ChatCompletionLogProbs,
ChatCompletionLogProbsContent, ChatCompletionNamedToolChoiceParam,
@ -20,7 +21,8 @@ from vllm.entrypoints.openai.protocol import (
ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse,
FunctionCall, ToolCall, UsageInfo)
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
OpenAIServing)
OpenAIServing,
PromptAdapterPath)
from vllm.inputs import PromptInputs
from vllm.logger import init_logger
from vllm.model_executor.guided_decoding import (
@ -37,17 +39,24 @@ logger = init_logger(__name__)
class OpenAIServingChat(OpenAIServing):
def __init__(self,
engine: AsyncLLMEngine,
model_config: ModelConfig,
served_model_names: List[str],
response_role: str,
lora_modules: Optional[List[LoRAModulePath]] = None,
chat_template: Optional[str] = None):
def __init__(
self,
engine: AsyncLLMEngine,
model_config: ModelConfig,
served_model_names: List[str],
response_role: str,
*,
lora_modules: Optional[List[LoRAModulePath]],
prompt_adapters: Optional[List[PromptAdapterPath]],
request_logger: Optional[RequestLogger],
chat_template: Optional[str],
):
super().__init__(engine=engine,
model_config=model_config,
served_model_names=served_model_names,
lora_modules=lora_modules)
lora_modules=lora_modules,
prompt_adapters=prompt_adapters,
request_logger=request_logger)
self.response_role = response_role
@ -74,7 +83,12 @@ class OpenAIServingChat(OpenAIServing):
return error_check_ret
try:
_, lora_request = self._maybe_get_adapter(request)
(
lora_request,
prompt_adapter_request,
) = self._maybe_get_adapters(request)
model_config = self.model_config
tokenizer = await self.engine.get_tokenizer(lora_request)
conversation: List[ConversationMessage] = []
@ -82,7 +96,7 @@ class OpenAIServingChat(OpenAIServing):
for msg in request.messages:
chat_parsed_result = parse_chat_message_content(
msg, self.model_config, tokenizer)
msg, model_config, tokenizer)
conversation.extend(chat_parsed_result.messages)
mm_futures.extend(chat_parsed_result.mm_futures)
@ -116,14 +130,8 @@ class OpenAIServingChat(OpenAIServing):
logger.error("Error in loading multi-modal data: %s", e)
return self.create_error_response(str(e))
request_id = f"cmpl-{random_uuid()}"
request_id = f"chat-{random_uuid()}"
try:
# Tokenize/detokenize depending on prompt format (string/token list)
prompt_ids, prompt_text = await self._validate_prompt_and_tokenize(
request,
tokenizer,
prompt=prompt,
add_special_tokens=request.add_special_tokens)
sampling_params = request.to_sampling_params()
decoding_config = await self.engine.get_decoding_config()
guided_decoding_backend = request.guided_decoding_backend \
@ -137,31 +145,47 @@ class OpenAIServingChat(OpenAIServing):
sampling_params.logits_processors = []
sampling_params.logits_processors.append(
guided_decode_logits_processor)
prompt_inputs = self._tokenize_prompt_input(
request,
tokenizer,
prompt,
truncate_prompt_tokens=sampling_params.truncate_prompt_tokens,
add_special_tokens=request.add_special_tokens,
)
self._log_inputs(request_id,
prompt_inputs,
params=sampling_params,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request)
engine_inputs: PromptInputs = {
"prompt_token_ids": prompt_inputs["prompt_token_ids"],
}
if mm_data is not None:
engine_inputs["multi_modal_data"] = mm_data
is_tracing_enabled = await self.engine.is_tracing_enabled()
trace_headers = None
if is_tracing_enabled and raw_request:
trace_headers = extract_trace_headers(raw_request.headers)
if (not is_tracing_enabled and raw_request
and contains_trace_headers(raw_request.headers)):
log_tracing_disabled_warning()
result_generator = self.engine.generate(
engine_inputs,
sampling_params,
request_id,
lora_request=lora_request,
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
)
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
inputs: PromptInputs = {
"prompt": prompt_text,
"prompt_token_ids": prompt_ids,
}
if mm_data:
inputs["multi_modal_data"] = mm_data
is_tracing_enabled = await self.engine.is_tracing_enabled()
trace_headers = None
if is_tracing_enabled and raw_request:
trace_headers = extract_trace_headers(raw_request.headers)
if not is_tracing_enabled and raw_request and contains_trace_headers(
raw_request.headers):
log_tracing_disabled_warning()
result_generator = self.engine.generate(
inputs,
sampling_params,
request_id,
lora_request,
trace_headers=trace_headers,
)
# Streaming response
if request.stream:
return self.chat_completion_stream_generator(
@ -195,10 +219,11 @@ class OpenAIServingChat(OpenAIServing):
first_iteration = True
# Send response for each token for each request.n (index)
assert request.n is not None
previous_texts = [""] * request.n
previous_num_tokens = [0] * request.n
finish_reason_sent = [False] * request.n
num_choices = 1 if request.n is None else request.n
previous_texts = [""] * num_choices
previous_num_tokens = [0] * num_choices
finish_reason_sent = [False] * num_choices
try:
async for res in result_generator:
# We need to do it here, because if there are exceptions in
@ -208,7 +233,7 @@ class OpenAIServingChat(OpenAIServing):
# Send first response for each request.n (index) with
# the role
role = self.get_chat_request_role(request)
for i in range(request.n):
for i in range(num_choices):
choice_data = ChatCompletionResponseStreamChoice(
index=i,
delta=DeltaMessage(role=role),
@ -236,19 +261,19 @@ class OpenAIServingChat(OpenAIServing):
last_msg_content = conversation[-1]["content"]
if last_msg_content:
for i in range(request.n):
for i in range(num_choices):
choice_data = (
ChatCompletionResponseStreamChoice(
index=i,
delta=DeltaMessage(
content=last_msg_content),
logprobs=None,
finish_reason=None))
chunk = ChatCompletionStreamResponse(
id=request_id,
object=chunk_object_type,
created=created_time,
choices=[choice_data],
logprobs=None,
model=model_name)
if (request.stream_options and
request.stream_options.include_usage):

View File

@ -2,13 +2,14 @@ import time
from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, List,
Optional)
from typing import Sequence as GenericSequence
from typing import Tuple
from typing import Tuple, cast
from fastapi import Request
from transformers import PreTrainedTokenizer
from vllm.config import ModelConfig
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.logger import RequestLogger
# yapf conflicts with isort for this block
# yapf: disable
from vllm.entrypoints.openai.protocol import (CompletionLogProbs,
@ -39,40 +40,24 @@ TypeCreateLogProbsFn = Callable[
[TypeTokenIDs, TypeTopLogProbs, Optional[int], int], CompletionLogProbs]
def parse_prompt_format(prompt) -> Tuple[bool, list]:
# get the prompt, openai supports the following
# "a string, array of strings, array of tokens, or array of token arrays."
prompt_is_tokens = False
prompts = [prompt] # case 1: a string
if isinstance(prompt, list):
if len(prompt) == 0:
raise ValueError("please provide at least one prompt")
elif isinstance(prompt[0], str):
prompt_is_tokens = False
prompts = prompt # case 2: array of strings
elif isinstance(prompt[0], int):
prompt_is_tokens = True
prompts = [prompt] # case 3: array of tokens
elif isinstance(prompt[0], list) and isinstance(prompt[0][0], int):
prompt_is_tokens = True
prompts = prompt # case 4: array of token arrays
else:
raise ValueError("prompt must be a string, array of strings, "
"array of tokens, or array of token arrays")
return prompt_is_tokens, prompts
class OpenAIServingCompletion(OpenAIServing):
def __init__(self, engine: AsyncLLMEngine, model_config: ModelConfig,
served_model_names: List[str],
lora_modules: Optional[List[LoRAModulePath]],
prompt_adapters: Optional[List[PromptAdapterPath]]):
def __init__(
self,
engine: AsyncLLMEngine,
model_config: ModelConfig,
served_model_names: List[str],
*,
lora_modules: Optional[List[LoRAModulePath]],
prompt_adapters: Optional[List[PromptAdapterPath]],
request_logger: Optional[RequestLogger],
):
super().__init__(engine=engine,
model_config=model_config,
served_model_names=served_model_names,
lora_modules=lora_modules,
prompt_adapters=prompt_adapters)
prompt_adapters=prompt_adapters,
request_logger=request_logger)
async def create_completion(self, request: CompletionRequest,
raw_request: Request):
@ -101,12 +86,11 @@ class OpenAIServingCompletion(OpenAIServing):
# Schedule the request and get the result generator.
generators: List[AsyncIterator[RequestOutput]] = []
try:
adapter_type, adapter_request = self._maybe_get_adapter(request)
lora_request, prompt_adapter_request = None, None
if adapter_type == 'LoRA':
lora_request, prompt_adapter_request = adapter_request, None
elif adapter_type == 'PromptAdapter':
lora_request, prompt_adapter_request = None, adapter_request
(
lora_request,
prompt_adapter_request,
) = self._maybe_get_adapters(request)
tokenizer = await self.engine.get_tokenizer(lora_request)
sampling_params = request.to_sampling_params()
@ -122,17 +106,25 @@ class OpenAIServingCompletion(OpenAIServing):
sampling_params.logits_processors = []
sampling_params.logits_processors.append(
guided_decode_logit_processor)
prompt_is_tokens, prompts = parse_prompt_format(request.prompt)
for i, prompt in enumerate(prompts):
prompt_arg = "prompt_ids" if prompt_is_tokens else "prompt"
prompt_formats = await self._validate_prompt_and_tokenize(
prompts = list(
self._tokenize_prompt_input_or_inputs(
request,
tokenizer,
request.prompt,
truncate_prompt_tokens=sampling_params.
truncate_prompt_tokens,
**{prompt_arg: prompt})
prompt_ids, prompt_text = prompt_formats
add_special_tokens=request.add_special_tokens,
))
for i, prompt_inputs in enumerate(prompts):
request_id_item = f"{request_id}-{i}"
self._log_inputs(request_id_item,
prompt_inputs,
params=sampling_params,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request)
is_tracing_enabled = await self.engine.is_tracing_enabled()
trace_headers = None
@ -143,12 +135,9 @@ class OpenAIServingCompletion(OpenAIServing):
log_tracing_disabled_warning()
generator = self.engine.generate(
{
"prompt": prompt_text,
"prompt_token_ids": prompt_ids
},
{"prompt_token_ids": prompt_inputs["prompt_token_ids"]},
sampling_params,
f"{request_id}-{i}",
request_id_item,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
trace_headers=trace_headers,
@ -189,9 +178,27 @@ class OpenAIServingCompletion(OpenAIServing):
await self.engine.abort(f"{request_id}-{i}")
return self.create_error_response("Client disconnected")
final_res_batch[i] = res
for i, final_res in enumerate(final_res_batch):
assert final_res is not None
# The output should contain the input text
# We did not pass it into vLLM engine to avoid being redundant
# with the inputs token IDs
if final_res.prompt is None:
final_res.prompt = prompts[i]["prompt"]
final_res_batch_checked = cast(List[RequestOutput],
final_res_batch)
response = self.request_output_to_completion_response(
final_res_batch, request, request_id, created_time, model_name,
tokenizer)
final_res_batch_checked,
request,
request_id,
created_time,
model_name,
tokenizer,
)
except ValueError as e:
# TODO: Use a vllm-specific Validation Error
return self.create_error_response(str(e))
@ -220,10 +227,10 @@ class OpenAIServingCompletion(OpenAIServing):
num_prompts: int,
tokenizer: PreTrainedTokenizer,
) -> AsyncGenerator[str, None]:
assert request.n is not None
previous_texts = [""] * request.n * num_prompts
previous_num_tokens = [0] * request.n * num_prompts
has_echoed = [False] * request.n * num_prompts
num_choices = 1 if request.n is None else request.n
previous_texts = [""] * num_choices * num_prompts
previous_num_tokens = [0] * num_choices * num_prompts
has_echoed = [False] * num_choices * num_prompts
try:
async for prompt_idx, res in result_generator:
@ -234,7 +241,7 @@ class OpenAIServingCompletion(OpenAIServing):
raise StopAsyncIteration()
for output in res.outputs:
i = output.index + prompt_idx * request.n
i = output.index + prompt_idx * num_choices
# TODO(simon): optimize the performance by avoiding full
# text O(n^2) sending.
@ -343,8 +350,8 @@ class OpenAIServingCompletion(OpenAIServing):
choices: List[CompletionResponseChoice] = []
num_prompt_tokens = 0
num_generated_tokens = 0
for final_res in final_res_batch:
assert final_res is not None
prompt_token_ids = final_res.prompt_token_ids
prompt_logprobs = final_res.prompt_logprobs
prompt_text = final_res.prompt

View File

@ -1,16 +1,16 @@
import base64
import time
from typing import AsyncIterator, List, Optional, Tuple
from typing import AsyncIterator, List, Optional, Tuple, cast
import numpy as np
from fastapi import Request
from vllm.config import ModelConfig
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (EmbeddingRequest,
EmbeddingResponse,
EmbeddingResponseData, UsageInfo)
from vllm.entrypoints.openai.serving_completion import parse_prompt_format
from vllm.entrypoints.openai.serving_engine import OpenAIServing
from vllm.logger import init_logger
from vllm.outputs import EmbeddingRequestOutput
@ -28,11 +28,11 @@ def request_output_to_embedding_response(
data: List[EmbeddingResponseData] = []
num_prompt_tokens = 0
for idx, final_res in enumerate(final_res_batch):
assert final_res is not None
prompt_token_ids = final_res.prompt_token_ids
embedding = final_res.outputs.embedding
if encoding_format == "base64":
embedding = base64.b64encode(np.array(embedding))
embedding_bytes = np.array(embedding).tobytes()
embedding = base64.b64encode(embedding_bytes).decode("utf-8")
embedding_data = EmbeddingResponseData(index=idx, embedding=embedding)
data.append(embedding_data)
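
Because the embedding is now serialized with `np.array(embedding).tobytes()` and base64-encoded into a UTF-8 string, a client can recover the vector by reversing those steps. A hedged sketch; the dtype follows NumPy's default for a Python float list (float64), so adjust it if the server casts the values first:

```python
# Hedged client-side sketch mirroring the encoding above.
import base64

import numpy as np


def decode_embedding(embedding_b64: str) -> np.ndarray:
    raw = base64.b64decode(embedding_b64)
    # np.array(list_of_floats).tobytes() produces float64 by default.
    return np.frombuffer(raw, dtype=np.float64)
```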
@ -54,12 +54,20 @@ def request_output_to_embedding_response(
class OpenAIServingEmbedding(OpenAIServing):
def __init__(self, engine: AsyncLLMEngine, model_config: ModelConfig,
served_model_names: List[str]):
def __init__(
self,
engine: AsyncLLMEngine,
model_config: ModelConfig,
served_model_names: List[str],
*,
request_logger: Optional[RequestLogger],
):
super().__init__(engine=engine,
model_config=model_config,
served_model_names=served_model_names,
lora_modules=None)
lora_modules=None,
prompt_adapters=None,
request_logger=request_logger)
self._check_embedding_mode(model_config.embedding_mode)
async def create_embedding(self, request: EmbeddingRequest,
@ -80,29 +88,47 @@ class OpenAIServingEmbedding(OpenAIServing):
"dimensions is currently not supported")
model_name = request.model
request_id = f"cmpl-{random_uuid()}"
request_id = f"embd-{random_uuid()}"
created_time = int(time.monotonic())
# Schedule the request and get the result generator.
generators = []
generators: List[AsyncIterator[EmbeddingRequestOutput]] = []
try:
prompt_is_tokens, prompts = parse_prompt_format(request.input)
(
lora_request,
prompt_adapter_request,
) = self._maybe_get_adapters(request)
tokenizer = await self.engine.get_tokenizer(lora_request)
pooling_params = request.to_pooling_params()
tokenizer = await self.engine.get_tokenizer()
for i, prompt in enumerate(prompts):
prompt_arg = "prompt_ids" if prompt_is_tokens else "prompt"
prompt_formats = await self._validate_prompt_and_tokenize(
request, tokenizer, **{prompt_arg: prompt})
prompt_ids, prompt_text = prompt_formats
prompts = list(
self._tokenize_prompt_input_or_inputs(
request,
tokenizer,
request.input,
))
for i, prompt_inputs in enumerate(prompts):
request_id_item = f"{request_id}-{i}"
self._log_inputs(request_id_item,
prompt_inputs,
params=pooling_params,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request)
if prompt_adapter_request is not None:
raise NotImplementedError(
"Prompt adapter is not supported "
"for embedding models")
generator = self.engine.encode(
{
"prompt": prompt_text,
"prompt_token_ids": prompt_ids
},
{"prompt_token_ids": prompt_inputs["prompt_token_ids"]},
pooling_params,
f"{request_id}-{i}",
request_id_item,
lora_request=lora_request,
)
generators.append(generator)
@ -121,11 +147,17 @@ class OpenAIServingEmbedding(OpenAIServing):
if await raw_request.is_disconnected():
# Abort the request if the client disconnects.
await self.engine.abort(f"{request_id}-{i}")
# TODO: Use a vllm-specific Validation Error
return self.create_error_response("Client disconnected")
final_res_batch[i] = res
for final_res in final_res_batch:
assert final_res is not None
final_res_batch_checked = cast(List[EmbeddingRequestOutput],
final_res_batch)
response = request_output_to_embedding_response(
final_res_batch, request_id, created_time, model_name,
final_res_batch_checked, request_id, created_time, model_name,
encoding_format)
except ValueError as e:
# TODO: Use a vllm-specific Validation Error

View File

@ -2,23 +2,33 @@ import json
import pathlib
from dataclasses import dataclass
from http import HTTPStatus
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Iterable, Iterator, List, Optional, Tuple, TypedDict, Union
from pydantic import Field
from transformers import PreTrainedTokenizer
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from typing_extensions import Annotated
from vllm.config import ModelConfig
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.logger import RequestLogger
# yapf conflicts with isort for this block
# yapf: disable
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
CompletionRequest,
DetokenizeRequest,
EmbeddingRequest, ErrorResponse,
ModelCard, ModelList,
ModelPermission, TokenizeRequest)
ModelPermission,
TokenizeChatRequest,
TokenizeCompletionRequest,
TokenizeRequest)
# yapf: enable
from vllm.inputs import parse_and_batch_prompt
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.sequence import Logprob
logger = init_logger(__name__)
@ -36,6 +46,17 @@ class LoRAModulePath:
local_path: str
AnyRequest = Union[ChatCompletionRequest, CompletionRequest, DetokenizeRequest,
EmbeddingRequest, TokenizeRequest]
AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
class TextTokensPrompt(TypedDict):
prompt: str
prompt_token_ids: List[int]
class OpenAIServing:
def __init__(
@ -43,8 +64,10 @@ class OpenAIServing:
engine: AsyncLLMEngine,
model_config: ModelConfig,
served_model_names: List[str],
*,
lora_modules: Optional[List[LoRAModulePath]],
prompt_adapters: Optional[List[PromptAdapterPath]] = None,
prompt_adapters: Optional[List[PromptAdapterPath]],
request_logger: Optional[RequestLogger],
):
super().__init__()
@ -78,6 +101,8 @@ class OpenAIServing:
prompt_adapter_local_path=prompt_adapter.local_path,
prompt_adapter_num_virtual_tokens=num_virtual_tokens))
self.request_logger = request_logger
async def show_available_models(self) -> ModelList:
"""Show available models. Right now we only have one model."""
model_cards = [
@ -126,9 +151,8 @@ class OpenAIServing:
return json_str
async def _check_model(
self, request: Union[ChatCompletionRequest, CompletionRequest,
DetokenizeRequest, EmbeddingRequest,
TokenizeRequest]
self,
request: AnyRequest,
) -> Optional[ErrorResponse]:
if request.model in self.served_model_names:
return None
@ -144,64 +168,65 @@ class OpenAIServing:
err_type="NotFoundError",
status_code=HTTPStatus.NOT_FOUND)
def _maybe_get_adapter(
self, request: Union[CompletionRequest, ChatCompletionRequest,
EmbeddingRequest, TokenizeRequest,
DetokenizeRequest]
) -> Tuple[Optional[str], Optional[Union[LoRARequest,
PromptAdapterRequest]]]:
def _maybe_get_adapters(
self, request: AnyRequest
) -> Union[Tuple[None, None], Tuple[LoRARequest, None], Tuple[
None, PromptAdapterRequest]]:
if request.model in self.served_model_names:
return None, None
for lora in self.lora_requests:
if request.model == lora.lora_name:
return 'LoRA', lora
return lora, None
for prompt_adapter in self.prompt_adapter_requests:
if request.model == prompt_adapter.prompt_adapter_name:
return 'PromptAdapter', prompt_adapter
return None, prompt_adapter
# if _check_model has been called earlier, this will be unreachable
raise ValueError(f"The model `{request.model}` does not exist.")
async def _validate_prompt_and_tokenize(
self,
request: Union[ChatCompletionRequest, CompletionRequest,
DetokenizeRequest, EmbeddingRequest,
TokenizeRequest],
tokenizer: "PreTrainedTokenizer",
prompt: Optional[str] = None,
prompt_ids: Optional[List[int]] = None,
truncate_prompt_tokens: Optional[Annotated[int,
Field(ge=1)]] = None,
add_special_tokens: Optional[bool] = True
) -> Tuple[List[int], str]:
if not (prompt or prompt_ids):
raise ValueError("Either prompt or prompt_ids should be provided.")
if prompt and prompt_ids:
raise ValueError(
"Only one of prompt or prompt_ids should be provided.")
if prompt_ids is None:
# When using OpenAIServingChat for chat completions, for
# most models the special tokens (e.g., BOS) have already
# been added by the chat template. Therefore, we do not
# need to add them again.
# Set add_special_tokens to False (by default) to avoid
# adding the BOS tokens again.
tokenizer_kwargs: Dict[str, Any] = {
"add_special_tokens": add_special_tokens
}
if truncate_prompt_tokens is not None:
tokenizer_kwargs.update({
"truncation": True,
"max_length": truncate_prompt_tokens,
})
input_ids = tokenizer(prompt, **tokenizer_kwargs).input_ids
elif truncate_prompt_tokens is not None:
input_ids = prompt_ids[-truncate_prompt_tokens:]
def _normalize_prompt_text_to_input(
self,
request: AnyRequest,
tokenizer: AnyTokenizer,
prompt: str,
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]],
add_special_tokens: bool,
) -> TextTokensPrompt:
if truncate_prompt_tokens is None:
encoded = tokenizer(prompt, add_special_tokens=add_special_tokens)
else:
input_ids = prompt_ids
encoded = tokenizer(prompt,
add_special_tokens=add_special_tokens,
truncation=True,
max_length=truncate_prompt_tokens)
input_text = prompt if prompt is not None else tokenizer.decode(
input_ids)
input_ids = encoded.input_ids
input_text = prompt
return self._validate_input(request, input_ids, input_text)
def _normalize_prompt_tokens_to_input(
self,
request: AnyRequest,
tokenizer: AnyTokenizer,
prompt_ids: List[int],
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]],
) -> TextTokensPrompt:
if truncate_prompt_tokens is None:
input_ids = prompt_ids
else:
input_ids = prompt_ids[-truncate_prompt_tokens:]
input_text = tokenizer.decode(input_ids)
return self._validate_input(request, input_ids, input_text)
def _validate_input(
self,
request: AnyRequest,
input_ids: List[int],
input_text: str,
) -> TextTokensPrompt:
token_num = len(input_ids)
# Note: EmbeddingRequest doesn't have max_tokens
@ -211,13 +236,16 @@ class OpenAIServing:
f"This model's maximum context length is "
f"{self.max_model_len} tokens. However, you requested "
f"{token_num} tokens in the input for embedding "
f"generation. Please reduce the length of the input.", )
return input_ids, input_text
f"generation. Please reduce the length of the input.")
return TextTokensPrompt(prompt=input_text,
prompt_token_ids=input_ids)
# Note: TokenizeRequest and DetokenizeRequest doesn't have max_tokens
# and does not require model context length validation
if isinstance(request, (TokenizeRequest, DetokenizeRequest)):
return input_ids, input_text
if isinstance(request, (TokenizeCompletionRequest, TokenizeChatRequest,
DetokenizeRequest)):
return TextTokensPrompt(prompt=input_text,
prompt_token_ids=input_ids)
if request.max_tokens is None:
if token_num >= self.max_model_len:
@ -225,7 +253,7 @@ class OpenAIServing:
f"This model's maximum context length is "
f"{self.max_model_len} tokens. However, you requested "
f"{token_num} tokens in the messages, "
f"Please reduce the length of the messages.", )
f"Please reduce the length of the messages.")
request.max_tokens = self.max_model_len - token_num
if token_num + request.max_tokens > self.max_model_len:
@ -235,13 +263,132 @@ class OpenAIServing:
f"{request.max_tokens + token_num} tokens "
f"({token_num} in the messages, "
f"{request.max_tokens} in the completion). "
f"Please reduce the length of the messages or completion.", )
f"Please reduce the length of the messages or completion.")
return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids)
def _tokenize_prompt_input(
self,
request: AnyRequest,
tokenizer: AnyTokenizer,
prompt_input: Union[str, List[int]],
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
add_special_tokens: bool = True,
) -> TextTokensPrompt:
"""
A simpler implementation of :meth:`_tokenize_prompt_input_or_inputs`
that assumes single input.
"""
return next(
self._tokenize_prompt_inputs(
request,
tokenizer,
[prompt_input],
truncate_prompt_tokens=truncate_prompt_tokens,
add_special_tokens=add_special_tokens,
))
def _tokenize_prompt_inputs(
self,
request: AnyRequest,
tokenizer: AnyTokenizer,
prompt_inputs: Iterable[Union[str, List[int]]],
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
add_special_tokens: bool = True,
) -> Iterator[TextTokensPrompt]:
"""
A simpler implementation of :meth:`_tokenize_prompt_input_or_inputs`
that assumes multiple inputs.
"""
for text in prompt_inputs:
if isinstance(text, str):
yield self._normalize_prompt_text_to_input(
request,
tokenizer,
prompt=text,
truncate_prompt_tokens=truncate_prompt_tokens,
add_special_tokens=add_special_tokens,
)
else:
yield self._normalize_prompt_tokens_to_input(
request,
tokenizer,
prompt_ids=text,
truncate_prompt_tokens=truncate_prompt_tokens,
)
def _tokenize_prompt_input_or_inputs(
self,
request: AnyRequest,
tokenizer: AnyTokenizer,
input_or_inputs: Union[str, List[str], List[int], List[List[int]]],
truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
add_special_tokens: bool = True,
) -> Iterator[TextTokensPrompt]:
"""
Tokenize/detokenize depending on the input format.
According to `OpenAI API <https://platform.openai.com/docs/api-reference/embeddings/create>`_
, each input can be a string or array of tokens. Note that each request
can pass one or more inputs.
"""
for prompt_input in parse_and_batch_prompt(input_or_inputs):
# Although our type checking is based on mypy,
# VSCode Pyright extension should still work properly
# "is True" is required for Pyright to perform type narrowing
# See: https://github.com/microsoft/pyright/issues/7672
if prompt_input["is_tokens"] is False:
yield self._normalize_prompt_text_to_input(
request,
tokenizer,
prompt=prompt_input["content"],
truncate_prompt_tokens=truncate_prompt_tokens,
add_special_tokens=add_special_tokens,
)
else:
yield self._normalize_prompt_tokens_to_input(
request,
tokenizer,
prompt_ids=prompt_input["content"],
truncate_prompt_tokens=truncate_prompt_tokens,
)
def _log_inputs(
self,
request_id: str,
inputs: Union[str, List[int], TextTokensPrompt],
params: Optional[Union[SamplingParams, PoolingParams]],
lora_request: Optional[LoRARequest],
prompt_adapter_request: Optional[PromptAdapterRequest],
) -> None:
if self.request_logger is None:
return
if isinstance(inputs, str):
prompt = inputs
prompt_token_ids = None
elif isinstance(inputs, list):
prompt = None
prompt_token_ids = inputs
else:
return input_ids, input_text
prompt = inputs["prompt"]
prompt_token_ids = inputs["prompt_token_ids"]
self.request_logger.log_inputs(
request_id,
prompt,
prompt_token_ids,
params=params,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request,
)
@staticmethod
def _get_decoded_token(logprob: Logprob, token_id: int,
tokenizer: PreTrainedTokenizer) -> str:
def _get_decoded_token(
logprob: Logprob,
token_id: int,
tokenizer: AnyTokenizer,
) -> str:
if logprob.decoded_token is not None:
return logprob.decoded_token
return tokenizer.decode(token_id)
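
The new `_tokenize_prompt_input_or_inputs` path leans on `parse_and_batch_prompt`, which accepts every prompt shape the OpenAI API allows and tags each parsed item with `is_tokens`/`content`. A hedged illustration of the accepted shapes:

```python
# Illustration of the prompt shapes parse_and_batch_prompt normalizes; each
# parsed item carries "content" and an "is_tokens" flag, which the serving
# helpers above turn into TextTokensPrompt entries.
from vllm.inputs import parse_and_batch_prompt

parse_and_batch_prompt("Hello, world!")             # a single text prompt
parse_and_batch_prompt(["Hello!", "How are you?"])  # an array of text prompts
parse_and_batch_prompt([1, 15043, 29892])           # a single token-ID prompt
parse_and_batch_prompt([[1, 15043], [1, 3186]])     # an array of token-ID prompts
```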

View File

@ -1,83 +1,135 @@
from typing import List, Optional
from typing import List, Optional, Union
from vllm.config import ModelConfig
from vllm.engine.async_llm_engine import AsyncLLMEngine
# yapf conflicts with isort for this block
# yapf: disable
from vllm.entrypoints.chat_utils import (ConversationMessage,
load_chat_template,
parse_chat_message_content)
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (DetokenizeRequest,
DetokenizeResponse,
ErrorResponse,
TokenizeChatRequest,
TokenizeRequest,
TokenizeResponse)
# yapf: enable
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
OpenAIServing)
from vllm.utils import random_uuid
class OpenAIServingTokenization(OpenAIServing):
def __init__(self,
engine: AsyncLLMEngine,
model_config: ModelConfig,
served_model_names: List[str],
lora_modules: Optional[List[LoRAModulePath]] = None,
chat_template: Optional[str] = None):
def __init__(
self,
engine: AsyncLLMEngine,
model_config: ModelConfig,
served_model_names: List[str],
*,
lora_modules: Optional[List[LoRAModulePath]],
request_logger: Optional[RequestLogger],
chat_template: Optional[str],
):
super().__init__(engine=engine,
model_config=model_config,
served_model_names=served_model_names,
lora_modules=lora_modules)
lora_modules=lora_modules,
prompt_adapters=None,
request_logger=request_logger)
# If this is None we use the tokenizer's default chat template
self.chat_template = load_chat_template(chat_template)
async def create_tokenize(self,
request: TokenizeRequest) -> TokenizeResponse:
async def create_tokenize(
self,
request: TokenizeRequest,
) -> Union[TokenizeResponse, ErrorResponse]:
error_check_ret = await self._check_model(request)
if error_check_ret is not None:
return error_check_ret
if not (request.prompt or request.messages):
return self.create_error_response(
"Either `prompt` or `messages` should be provided.")
request_id = f"tokn-{random_uuid()}"
if (request.prompt and request.messages):
return self.create_error_response(
"Only one of `prompt` or `messages` should be provided.")
(
lora_request,
prompt_adapter_request,
) = self._maybe_get_adapters(request)
_, lora_request = self._maybe_get_adapter(request)
tokenizer = await self.engine.get_tokenizer(lora_request)
if request.messages:
if isinstance(request, TokenizeChatRequest):
model_config = self.model_config
conversation: List[ConversationMessage] = []
for message in request.messages:
result = parse_chat_message_content(message, self.model_config,
result = parse_chat_message_content(message, model_config,
tokenizer)
conversation.extend(result.messages)
request.prompt = tokenizer.apply_chat_template(
prompt = tokenizer.apply_chat_template(
add_generation_prompt=request.add_generation_prompt,
conversation=conversation,
tokenize=False,
chat_template=self.chat_template)
assert isinstance(prompt, str)
else:
prompt = request.prompt
(input_ids, input_text) = await self._validate_prompt_and_tokenize(
self._log_inputs(request_id,
prompt,
params=None,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request)
# Silently ignore prompt adapter since it does not affect tokenization
prompt_input = self._tokenize_prompt_input(
request,
tokenizer,
prompt=request.prompt,
add_special_tokens=request.add_special_tokens)
prompt,
add_special_tokens=request.add_special_tokens,
)
input_ids = prompt_input["prompt_token_ids"]
return TokenizeResponse(tokens=input_ids,
count=len(input_ids),
max_model_len=self.max_model_len)
async def create_detokenize(
self, request: DetokenizeRequest) -> DetokenizeResponse:
self,
request: DetokenizeRequest,
) -> Union[DetokenizeResponse, ErrorResponse]:
error_check_ret = await self._check_model(request)
if error_check_ret is not None:
return error_check_ret
_, lora_request = self._maybe_get_adapter(request)
request_id = f"tokn-{random_uuid()}"
(
lora_request,
prompt_adapter_request,
) = self._maybe_get_adapters(request)
tokenizer = await self.engine.get_tokenizer(lora_request)
(input_ids, input_text) = await self._validate_prompt_and_tokenize(
request, tokenizer, prompt_ids=request.tokens)
self._log_inputs(request_id,
request.tokens,
params=None,
lora_request=lora_request,
prompt_adapter_request=prompt_adapter_request)
if prompt_adapter_request is not None:
raise NotImplementedError("Prompt adapter is not supported "
"for tokenization")
prompt_input = self._tokenize_prompt_input(
request,
tokenizer,
request.tokens,
)
input_text = prompt_input["prompt"]
return DetokenizeResponse(prompt=input_text)

View File

@ -1,6 +1,5 @@
from .data import (LLMInputs, ParsedText, ParsedTokens, PromptInputs,
PromptStrictInputs, TextPrompt, TextTokensPrompt,
TokensPrompt, parse_and_batch_prompt)
TextPrompt, TokensPrompt, parse_and_batch_prompt)
from .registry import InputContext, InputRegistry
INPUT_REGISTRY = InputRegistry()
@ -14,6 +13,6 @@ See also:
__all__ = [
"ParsedText", "ParsedTokens", "parse_and_batch_prompt", "TextPrompt",
"TokensPrompt", "TextTokensPrompt", "PromptStrictInputs", "PromptInputs",
"LLMInputs", "INPUT_REGISTRY", "InputContext", "InputRegistry"
"TokensPrompt", "PromptInputs", "LLMInputs", "INPUT_REGISTRY",
"InputContext", "InputRegistry"
]

View File

@ -92,25 +92,7 @@ class TokensPrompt(TypedDict):
"""
class TextTokensPrompt(TypedDict):
"""It is assumed that :attr:`prompt` is consistent with
:attr:`prompt_token_ids`. This is currently used in
:class:`AsyncLLMEngine` for logging both the text and token IDs."""
prompt: str
"""The prompt text."""
prompt_token_ids: List[int]
"""The token IDs of the prompt."""
multi_modal_data: NotRequired["MultiModalDataDict"]
"""
Optional multi-modal data to pass to the model,
if the model supports it.
"""
PromptStrictInputs = Union[str, TextPrompt, TokensPrompt]
PromptInputs = Union[str, TextPrompt, TokensPrompt]
"""
The inputs to the LLM, which can take one of the following forms:
@ -118,10 +100,6 @@ The inputs to the LLM, which can take one of the following forms:
- A tokenized prompt (:class:`TokensPrompt`)
"""
PromptInputs = Union[str, TextPrompt, TokensPrompt, TextTokensPrompt]
"""Same as :const:`PromptStrictInputs` but additionally accepts
:class:`TextTokensPrompt`."""
class LLMInputs(TypedDict):
"""

View File

@ -5,7 +5,8 @@ import math
from abc import ABC, abstractmethod
from collections import defaultdict
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Union
from typing import (TYPE_CHECKING, Dict, List, Mapping, Optional, Set, Tuple,
Union)
import torch
@ -438,7 +439,7 @@ class SequenceGroup:
embeddings: Optional[List[float]] = None,
pooling_params: Optional[PoolingParams] = None,
encoder_seq: Optional[Sequence] = None,
trace_headers: Optional[Dict[str, str]] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
) -> None:
self.request_id = request_id