mirror of https://github.com/vllm-project/vllm.git
[Frontend] Refactor prompt processing (#4028)
Co-authored-by: Roger Wang <ywang@roblox.com>
parent 89c1c6a196
commit 739b61a348
@@ -11,7 +11,7 @@ from tqdm import tqdm
 from vllm import LLM, SamplingParams
 from vllm.engine.arg_utils import EngineArgs
-from vllm.inputs import PromptStrictInputs
+from vllm.inputs import PromptInputs
 from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
 from vllm.utils import FlexibleArgumentParser
@@ -61,7 +61,7 @@ def main(args: argparse.Namespace):
     dummy_prompt_token_ids = np.random.randint(10000,
                                                size=(args.batch_size,
                                                      args.input_len))
-    dummy_inputs: List[PromptStrictInputs] = [{
+    dummy_inputs: List[PromptInputs] = [{
        "prompt_token_ids": batch
    } for batch in dummy_prompt_token_ids.tolist()]

@@ -8,7 +8,7 @@ Multi-Modality
 vLLM provides experimental support for multi-modal models through the :mod:`vllm.multimodal` package.
 
 Multi-modal inputs can be passed alongside text and token prompts to :ref:`supported models <supported_vlms>`
-via the ``multi_modal_data`` field in :class:`vllm.inputs.PromptStrictInputs`.
+via the ``multi_modal_data`` field in :class:`vllm.inputs.PromptInputs`.
 
 Currently, vLLM only has built-in support for image data. You can extend vLLM to process additional modalities
 by following :ref:`this guide <adding_multimodal_plugin>`.

@@ -1,7 +1,7 @@
 LLM Inputs
 ==========
 
-.. autodata:: vllm.inputs.PromptStrictInputs
+.. autodata:: vllm.inputs.PromptInputs
 
 .. autoclass:: vllm.inputs.TextPrompt
     :show-inheritance:

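For illustration only (not part of this commit): a minimal sketch of how the renamed PromptInputs type is consumed by LLM.generate after this change, assuming facebook/opt-125m purely as a placeholder model. A prompt may be a plain string, a TextPrompt dict, or a TokensPrompt dict.

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")  # placeholder model for this sketch
outputs = llm.generate(
    [
        "Hello, my name is",                     # plain string prompt
        {"prompt": "The capital of France is"},  # TextPrompt
        {"prompt_token_ids": [1, 2, 3, 4]},      # TokensPrompt
    ],
    SamplingParams(max_tokens=16),
)
for out in outputs:
    print(out.outputs[0].text)  # generated continuation for each prompt
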
@@ -30,7 +30,7 @@ To initialize a VLM, the aforementioned arguments must be passed to the ``LLM``
 internally for each model.
 
-To pass an image to the model, note the following in :class:`vllm.inputs.PromptStrictInputs`:
+To pass an image to the model, note the following in :class:`vllm.inputs.PromptInputs`:
 
 * ``prompt``: The prompt should follow the format that is documented on HuggingFace.
 * ``multi_modal_data``: This is a dictionary that follows the schema defined in :class:`vllm.multimodal.MultiModalDataDict`.

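Again for illustration (not from the diff): a hedged sketch of the prompt and multi_modal_data fields described above, assuming a LLaVA-1.5 style model and that the image is supplied as a PIL image under the "image" key of MultiModalDataDict.

from PIL import Image

from vllm import LLM, SamplingParams

llm = LLM(model="llava-hf/llava-1.5-7b-hf")  # assumed example VLM
image = Image.open("example.jpg")            # hypothetical local image file

outputs = llm.generate(
    {
        "prompt": "USER: <image>\nWhat is shown in this image? ASSISTANT:",
        "multi_modal_data": {"image": image},  # assumed MultiModalDataDict schema
    },
    SamplingParams(max_tokens=64),
)
print(outputs[0].outputs[0].text)
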
@@ -35,8 +35,8 @@ def sequence_with_eos(text: str, eos_token: str,
 @pytest.mark.parametrize(["text_wo_eos", "eos_token", "eos_token_id"], [
     ("This text ends with EOS token", "</s>", 2),
 ])
-@pytest.mark.parametrize("ignore_eos", [True, False, None])
-@pytest.mark.parametrize("include_stop_str_in_output", [True, False, None])
+@pytest.mark.parametrize("ignore_eos", [True, False])
+@pytest.mark.parametrize("include_stop_str_in_output", [True, False])
 @pytest.mark.skip_global_cleanup
 def test_stop_on_eos_token(text_wo_eos: str, eos_token: str, eos_token_id: int,
                            ignore_eos: bool, include_stop_str_in_output: bool):

@@ -32,7 +32,10 @@ async def _async_serving_chat_init():
         model_config,
         served_model_names=[MODEL_NAME],
         response_role="assistant",
-        chat_template=CHAT_TEMPLATE)
+        chat_template=CHAT_TEMPLATE,
+        lora_modules=None,
+        prompt_adapters=None,
+        request_logger=None)
     return serving_completion

@@ -5,7 +5,7 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.engine.llm_engine import LLMEngine
 from vllm.entrypoints.llm import LLM
 from vllm.executor.ray_utils import initialize_ray_cluster
-from vllm.inputs import PromptStrictInputs, TextPrompt, TokensPrompt
+from vllm.inputs import PromptInputs, TextPrompt, TokensPrompt
 from vllm.model_executor.models import ModelRegistry
 from vllm.outputs import (CompletionOutput, EmbeddingOutput,
                           EmbeddingRequestOutput, RequestOutput)
@@ -19,7 +19,7 @@ __all__ = [
     "__version__",
     "LLM",
     "ModelRegistry",
-    "PromptStrictInputs",
+    "PromptInputs",
     "TextPrompt",
     "TokensPrompt",
     "SamplingParams",

@@ -827,7 +827,6 @@ class AsyncEngineArgs(EngineArgs):
     """Arguments for asynchronous vLLM engine."""
     engine_use_ray: bool = False
     disable_log_requests: bool = False
-    max_log_len: Optional[int] = None
 
     @staticmethod
     def add_cli_args(parser: FlexibleArgumentParser,
@@ -841,12 +840,6 @@ class AsyncEngineArgs(EngineArgs):
         parser.add_argument('--disable-log-requests',
                             action='store_true',
                             help='Disable logging requests.')
-        parser.add_argument('--max-log-len',
-                            type=int,
-                            default=None,
-                            help='Max number of prompt characters or prompt '
-                            'ID numbers being printed in log.'
-                            '\n\nDefault: Unlimited')
         return parser

@@ -1,8 +1,8 @@
 import asyncio
 import time
 from functools import partial
-from typing import (AsyncIterator, Callable, Dict, Iterable, List, Optional,
-                    Set, Tuple, Type, Union)
+from typing import (AsyncIterator, Callable, Dict, Iterable, List, Mapping,
+                    Optional, Set, Tuple, Type, Union)
 
 from transformers import PreTrainedTokenizer
@@ -151,7 +151,10 @@ class RequestTracker:
             logger.info("Finished request %s.", request_id)
         self.abort_request(request_id)
 
-    def add_request(self, request_id: str,
+    def add_request(self,
+                    request_id: str,
+                    *,
+                    verbose: bool = False,
                     **engine_add_request_kwargs) -> AsyncStream:
         """Add a request to be sent to the engine on the next background
         loop iteration."""
@@ -166,6 +169,9 @@ class RequestTracker:
 
         self.new_requests_event.set()
 
+        if verbose:
+            logger.info("Added request %s.", request_id)
+
         return stream
 
     def abort_request(self, request_id: str, *, verbose: bool = False) -> None:
@@ -299,14 +305,14 @@ class _AsyncLLMEngine(LLMEngine):
         return self.input_processor(llm_inputs)
 
     async def add_request_async(
-        self,
-        request_id: str,
-        inputs: PromptInputs,
-        params: Union[SamplingParams, PoolingParams],
-        arrival_time: Optional[float] = None,
-        lora_request: Optional[LoRARequest] = None,
-        trace_headers: Optional[Dict[str, str]] = None,
-        prompt_adapter_request: Optional[PromptAdapterRequest] = None
+        self,
+        request_id: str,
+        inputs: PromptInputs,
+        params: Union[SamplingParams, PoolingParams],
+        arrival_time: Optional[float] = None,
+        lora_request: Optional[LoRARequest] = None,
+        trace_headers: Optional[Mapping[str, str]] = None,
+        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
     ) -> None:
         if lora_request is not None and not self.lora_config:
             raise ValueError(f"Got lora_request {lora_request} but LoRA is "
@@ -353,8 +359,6 @@ class AsyncLLMEngine:
         async frontend will be executed in a separate process as the
         model workers.
         log_requests: Whether to log the requests.
-        max_log_len: Maximum number of prompt characters or prompt ID numbers
-            being printed in log.
         start_engine_loop: If True, the background task to run the engine
             will be automatically started in the generate call.
         *args: Arguments for :class:`LLMEngine`.
@@ -368,13 +372,11 @@ class AsyncLLMEngine:
                  engine_use_ray: bool,
                  *args,
                  log_requests: bool = True,
-                 max_log_len: Optional[int] = None,
                  start_engine_loop: bool = True,
                  **kwargs) -> None:
         self.worker_use_ray = worker_use_ray
         self.engine_use_ray = engine_use_ray
         self.log_requests = log_requests
-        self.max_log_len = max_log_len
         self.engine = self._init_engine(*args, **kwargs)
 
         self.background_loop: Optional[asyncio.Future] = None
@@ -468,7 +470,6 @@ class AsyncLLMEngine:
             executor_class=executor_class,
             log_requests=not engine_args.disable_log_requests,
             log_stats=not engine_args.disable_log_stats,
-            max_log_len=engine_args.max_log_len,
             start_engine_loop=start_engine_loop,
             usage_context=usage_context,
             stat_loggers=stat_loggers,
@@ -667,30 +668,9 @@ class AsyncLLMEngine:
         params: Union[SamplingParams, PoolingParams],
         arrival_time: Optional[float] = None,
         lora_request: Optional[LoRARequest] = None,
-        trace_headers: Optional[Dict[str, str]] = None,
+        trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None
     ) -> AsyncStream:
-        if self.log_requests:
-            if isinstance(inputs, str):
-                shortened_prompt = inputs
-                shortened_token_ids = None
-            else:
-                shortened_prompt = inputs.get("prompt")
-                shortened_token_ids = inputs.get("prompt_token_ids")
-
-            max_log_len = self.max_log_len
-            if max_log_len is not None:
-                if shortened_prompt is not None:
-                    shortened_prompt = shortened_prompt[:max_log_len]
-                if shortened_token_ids is not None:
-                    shortened_token_ids = shortened_token_ids[:max_log_len]
-
-            logger.info(
-                "Received request %s: prompt: %r, "
-                "params: %s, prompt_token_ids: %s, "
-                "lora_request: %s.", request_id, shortened_prompt, params,
-                shortened_token_ids, lora_request)
-
         if not self.is_running:
             if self.start_engine_loop:
                 self.start_background_loop()
@@ -706,6 +686,7 @@ class AsyncLLMEngine:
 
         stream = self._request_tracker.add_request(
             request_id,
+            verbose=self.log_requests,
             inputs=inputs,
             params=params,
            arrival_time=arrival_time,
@@ -721,7 +702,7 @@ class AsyncLLMEngine:
         sampling_params: SamplingParams,
         request_id: str,
         lora_request: Optional[LoRARequest] = None,
-        trace_headers: Optional[Dict[str, str]] = None,
+        trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None
     ) -> AsyncIterator[RequestOutput]:
         """Generate outputs for a request.
@@ -804,7 +785,7 @@ class AsyncLLMEngine:
         pooling_params: PoolingParams,
         request_id: str,
         lora_request: Optional[LoRARequest] = None,
-        trace_headers: Optional[Dict[str, str]] = None,
+        trace_headers: Optional[Mapping[str, str]] = None,
     ) -> AsyncIterator[EmbeddingRequestOutput]:
         """Generate outputs for a request from an embedding model.
@@ -882,7 +863,7 @@ class AsyncLLMEngine:
         params: Union[SamplingParams, PoolingParams],
         *,
         lora_request: Optional[LoRARequest] = None,
-        trace_headers: Optional[Dict[str, str]] = None,
+        trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
     ) -> AsyncIterator[Union[RequestOutput, EmbeddingRequestOutput]]:
         """Common logic to process requests with SamplingParams or

@@ -1,6 +1,7 @@
 import time
 from contextlib import contextmanager
-from typing import TYPE_CHECKING, Any, ClassVar, Dict, Iterable, List, Optional
+from typing import (TYPE_CHECKING, Any, ClassVar, Dict, Iterable, List,
+                    Mapping, Optional)
 from typing import Sequence as GenericSequence
 from typing import Set, Type, TypeVar, Union
@@ -522,7 +523,7 @@ class LLMEngine:
         arrival_time: float,
         lora_request: Optional[LoRARequest],
         prompt_adapter_request: Optional[PromptAdapterRequest],
-        trace_headers: Optional[Dict[str, str]] = None,
+        trace_headers: Optional[Mapping[str, str]] = None,
     ) -> None:
         # Create the sequences.
         block_size = self.cache_config.block_size
@@ -603,7 +604,7 @@ class LLMEngine:
         params: Union[SamplingParams, PoolingParams],
         arrival_time: Optional[float] = None,
         lora_request: Optional[LoRARequest] = None,
-        trace_headers: Optional[Dict[str, str]] = None,
+        trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
     ) -> None:
         """Add a request to the engine's request pool.
@@ -677,7 +678,7 @@ class LLMEngine:
         sampling_params: SamplingParams,
         arrival_time: float,
         lora_request: Optional[LoRARequest],
-        trace_headers: Optional[Dict[str, str]] = None,
+        trace_headers: Optional[Mapping[str, str]] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
     ) -> SequenceGroup:
         """Creates a SequenceGroup with SamplingParams."""

@@ -6,8 +6,7 @@ from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
 
 from vllm.engine.arg_utils import EngineArgs
 from vllm.engine.llm_engine import LLMEngine
-from vllm.inputs import (PromptInputs, PromptStrictInputs, TextPrompt,
-                         TextTokensPrompt, TokensPrompt,
+from vllm.inputs import (PromptInputs, TextPrompt, TokensPrompt,
                          parse_and_batch_prompt)
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
@@ -238,7 +237,7 @@ class LLM:
     @overload
     def generate(
         self,
-        inputs: Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
+        inputs: Union[PromptInputs, Sequence[PromptInputs]],
         /,  # We may enable `inputs` keyword after removing the old API
         *,
         sampling_params: Optional[Union[SamplingParams,
@@ -255,7 +254,7 @@ class LLM:
                 "instead.")
     def generate(
         self,
-        prompts: Union[Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
+        prompts: Union[Union[PromptInputs, Sequence[PromptInputs]],
                        Optional[Union[str, List[str]]]] = None,
         sampling_params: Optional[Union[SamplingParams,
                                         Sequence[SamplingParams]]] = None,
@@ -302,9 +301,7 @@ class LLM:
                 prompt_token_ids=prompt_token_ids,
             )
         else:
-            inputs = cast(
-                Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
-                prompts)
+            inputs = cast(Union[PromptInputs, Sequence[PromptInputs]], prompts)
 
         if sampling_params is None:
             # Use default sampling params.
@@ -383,7 +380,7 @@ class LLM:
     @overload
     def encode(
         self,
-        inputs: Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
+        inputs: Union[PromptInputs, Sequence[PromptInputs]],
         /,  # We may enable `inputs` keyword after removing the old API
         *,
         pooling_params: Optional[Union[PoolingParams,
@@ -400,7 +397,7 @@ class LLM:
                 "instead.")
     def encode(
         self,
-        prompts: Union[Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
+        prompts: Union[Union[PromptInputs, Sequence[PromptInputs]],
                        Optional[Union[str, List[str]]]] = None,
         pooling_params: Optional[Union[PoolingParams,
                                        Sequence[PoolingParams]]] = None,
@@ -417,7 +414,7 @@ class LLM:
 
         Args:
             inputs: The inputs to the LLM. You may pass a sequence of inputs for
-                batch inference. See :class:`~vllm.inputs.PromptStrictInputs`
+                batch inference. See :class:`~vllm.inputs.PromptInputs`
                 for more details about the format of each input.
             pooling_params: The pooling parameters for pooling. If None, we
                 use the default pooling parameters.
@@ -446,9 +443,7 @@ class LLM:
                 prompt_token_ids=prompt_token_ids,
             )
         else:
-            inputs = cast(
-                Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
-                prompts)
+            inputs = cast(Union[PromptInputs, Sequence[PromptInputs]], prompts)
 
         if pooling_params is None:
             # Use default pooling params.
@@ -496,17 +491,11 @@ class LLM:
         inputs: List[PromptInputs] = []
         for i in range(num_requests):
             if prompts is not None:
-                if prompt_token_ids is not None:
-                    item = TextTokensPrompt(
-                        prompt=prompts[i],
-                        prompt_token_ids=prompt_token_ids[i])
-                else:
-                    item = TextPrompt(prompt=prompts[i])
+                item = TextPrompt(prompt=prompts[i])
+            elif prompt_token_ids is not None:
+                item = TokensPrompt(prompt_token_ids=prompt_token_ids[i])
             else:
-                if prompt_token_ids is not None:
-                    item = TokensPrompt(prompt_token_ids=prompt_token_ids[i])
-                else:
-                    raise AssertionError
+                raise AssertionError
 
             inputs.append(item)
@@ -514,7 +503,7 @@ class LLM:
 
     def _validate_and_add_requests(
         self,
-        inputs: Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
+        inputs: Union[PromptInputs, Sequence[PromptInputs]],
         params: Union[SamplingParams, Sequence[SamplingParams], PoolingParams,
                       Sequence[PoolingParams]],
         lora_request: Optional[Union[Sequence[LoRARequest], LoRARequest]],

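As an illustration only (not part of the commit): the encode docstring above describes the same PromptInputs format for embedding models. A minimal hedged sketch, assuming intfloat/e5-mistral-7b-instruct purely as an example embedding model.

from vllm import LLM

llm = LLM(model="intfloat/e5-mistral-7b-instruct")  # assumed embedding model
outputs = llm.encode([
    "Hello, world!",
    {"prompt_token_ids": [1, 2, 3]},  # TokensPrompt also works
])
for out in outputs:
    print(len(out.outputs.embedding))  # length of each embedding vector
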
@@ -0,0 +1,41 @@
+from typing import List, Optional, Union
+
+from vllm.logger import init_logger
+from vllm.lora.request import LoRARequest
+from vllm.pooling_params import PoolingParams
+from vllm.prompt_adapter.request import PromptAdapterRequest
+from vllm.sampling_params import SamplingParams
+
+logger = init_logger(__name__)
+
+
+class RequestLogger:
+
+    def __init__(self, *, max_log_len: Optional[int]) -> None:
+        super().__init__()
+
+        self.max_log_len = max_log_len
+
+    def log_inputs(
+        self,
+        request_id: str,
+        prompt: Optional[str],
+        prompt_token_ids: Optional[List[int]],
+        params: Optional[Union[SamplingParams, PoolingParams]],
+        lora_request: Optional[LoRARequest],
+        prompt_adapter_request: Optional[PromptAdapterRequest],
+    ) -> None:
+        max_log_len = self.max_log_len
+        if max_log_len is not None:
+            if prompt is not None:
+                prompt = prompt[:max_log_len]
+
+            if prompt_token_ids is not None:
+                prompt_token_ids = prompt_token_ids[:max_log_len]
+
+        logger.info(
+            "Received request %s: prompt: %r, "
+            "params: %s, prompt_token_ids: %s, "
+            "lora_request: %s, prompt_adapter_request: %s.", request_id,
+            prompt, params, prompt_token_ids, lora_request,
+            prompt_adapter_request)

@@ -18,6 +18,7 @@ from starlette.routing import Mount
 import vllm.envs as envs
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
+from vllm.entrypoints.logger import RequestLogger
 from vllm.entrypoints.openai.cli_args import make_arg_parser
 # yapf conflicts with isort for this block
 # yapf: disable
@@ -244,24 +245,48 @@ def run_server(args, llm_engine=None):
         # When using single vLLM without engine_use_ray
         model_config = asyncio.run(engine.get_model_config())
 
+    if args.disable_log_requests:
+        request_logger = None
+    else:
+        request_logger = RequestLogger(max_log_len=args.max_log_len)
+
     global openai_serving_chat
     global openai_serving_completion
     global openai_serving_embedding
     global openai_serving_tokenization
 
-    openai_serving_chat = OpenAIServingChat(engine, model_config,
-                                            served_model_names,
-                                            args.response_role,
-                                            args.lora_modules,
-                                            args.chat_template)
+    openai_serving_chat = OpenAIServingChat(
+        engine,
+        model_config,
+        served_model_names,
+        args.response_role,
+        lora_modules=args.lora_modules,
+        prompt_adapters=args.prompt_adapters,
+        request_logger=request_logger,
+        chat_template=args.chat_template,
+    )
     openai_serving_completion = OpenAIServingCompletion(
-        engine, model_config, served_model_names, args.lora_modules,
-        args.prompt_adapters)
-    openai_serving_embedding = OpenAIServingEmbedding(engine, model_config,
-                                                      served_model_names)
+        engine,
+        model_config,
+        served_model_names,
+        lora_modules=args.lora_modules,
+        prompt_adapters=args.prompt_adapters,
+        request_logger=request_logger,
+    )
+    openai_serving_embedding = OpenAIServingEmbedding(
+        engine,
+        model_config,
+        served_model_names,
+        request_logger=request_logger,
+    )
     openai_serving_tokenization = OpenAIServingTokenization(
-        engine, model_config, served_model_names, args.lora_modules,
-        args.chat_template)
+        engine,
+        model_config,
+        served_model_names,
+        lora_modules=args.lora_modules,
+        request_logger=request_logger,
+        chat_template=args.chat_template,
+    )
     app.root_path = args.root_path
 
     logger.info("Available routes are:")

@@ -130,6 +130,14 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
         "using app.add_middleware(). ")
 
     parser = AsyncEngineArgs.add_cli_args(parser)
 
+    parser.add_argument('--max-log-len',
+                        type=int,
+                        default=None,
+                        help='Max number of prompt characters or prompt '
+                        'ID numbers being printed in log.'
+                        '\n\nDefault: Unlimited')
+
     return parser

@@ -121,40 +121,42 @@ class ChatCompletionRequest(OpenAIBaseModel):
 
     # doc: begin-chat-completion-sampling-params
     best_of: Optional[int] = None
-    use_beam_search: Optional[bool] = False
-    top_k: Optional[int] = -1
-    min_p: Optional[float] = 0.0
-    repetition_penalty: Optional[float] = 1.0
-    length_penalty: Optional[float] = 1.0
-    early_stopping: Optional[bool] = False
-    ignore_eos: Optional[bool] = False
-    min_tokens: Optional[int] = 0
+    use_beam_search: bool = False
+    top_k: int = -1
+    min_p: float = 0.0
+    repetition_penalty: float = 1.0
+    length_penalty: float = 1.0
+    early_stopping: bool = False
     stop_token_ids: Optional[List[int]] = Field(default_factory=list)
-    skip_special_tokens: Optional[bool] = True
-    spaces_between_special_tokens: Optional[bool] = True
+    include_stop_str_in_output: bool = False
+    ignore_eos: bool = False
+    min_tokens: int = 0
+    skip_special_tokens: bool = True
+    spaces_between_special_tokens: bool = True
     truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
     # doc: end-chat-completion-sampling-params
 
     # doc: begin-chat-completion-extra-params
-    echo: Optional[bool] = Field(
+    echo: bool = Field(
         default=False,
         description=(
             "If true, the new message will be prepended with the last message "
             "if they belong to the same role."),
     )
-    add_generation_prompt: Optional[bool] = Field(
+    add_generation_prompt: bool = Field(
         default=True,
         description=
         ("If true, the generation prompt will be added to the chat template. "
          "This is a parameter used by chat template in tokenizer config of the "
          "model."),
     )
-    add_special_tokens: Optional[bool] = Field(
+    add_special_tokens: bool = Field(
         default=False,
         description=(
             "If true, special tokens (e.g. BOS) will be added to the prompt "
             "on top of what is added by the chat template. "
             "For most models, the chat template takes care of adding the "
-            "special tokens so this should be set to False (as is the "
+            "special tokens so this should be set to false (as is the "
             "default)."),
     )
     documents: Optional[List[Dict[str, str]]] = Field(
@@ -178,12 +180,6 @@ class ChatCompletionRequest(OpenAIBaseModel):
         description=("Additional kwargs to pass to the template renderer. "
                      "Will be accessible by the chat template."),
     )
-    include_stop_str_in_output: Optional[bool] = Field(
-        default=False,
-        description=(
-            "Whether to include the stop string in the output. "
-            "This is only applied when the stop or stop_token_ids is set."),
-    )
     guided_json: Optional[Union[str, dict, BaseModel]] = Field(
         default=None,
         description=("If specified, the output will follow the JSON schema."),
@@ -244,22 +240,22 @@ class ChatCompletionRequest(OpenAIBaseModel):
 
         return SamplingParams(
             n=self.n,
-            best_of=self.best_of,
             presence_penalty=self.presence_penalty,
             frequency_penalty=self.frequency_penalty,
             repetition_penalty=self.repetition_penalty,
             temperature=self.temperature,
             top_p=self.top_p,
-            top_k=self.top_k,
             min_p=self.min_p,
             seed=self.seed,
             stop=self.stop,
             stop_token_ids=self.stop_token_ids,
-            max_tokens=self.max_tokens,
-            min_tokens=self.min_tokens,
             logprobs=self.top_logprobs if self.logprobs else None,
             prompt_logprobs=self.top_logprobs if self.echo else None,
+            best_of=self.best_of,
+            top_k=self.top_k,
             ignore_eos=self.ignore_eos,
+            max_tokens=self.max_tokens,
+            min_tokens=self.min_tokens,
             use_beam_search=self.use_beam_search,
             early_stopping=self.early_stopping,
             skip_special_tokens=self.skip_special_tokens,
@@ -267,6 +263,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
             include_stop_str_in_output=self.include_stop_str_in_output,
             length_penalty=self.length_penalty,
             logits_processors=logits_processors,
+            truncate_prompt_tokens=self.truncate_prompt_tokens,
         )
 
     @model_validator(mode='before')
@@ -348,26 +345,27 @@ class CompletionRequest(OpenAIBaseModel):
     user: Optional[str] = None
 
     # doc: begin-completion-sampling-params
-    use_beam_search: Optional[bool] = False
-    top_k: Optional[int] = -1
-    min_p: Optional[float] = 0.0
-    repetition_penalty: Optional[float] = 1.0
-    length_penalty: Optional[float] = 1.0
-    early_stopping: Optional[bool] = False
+    use_beam_search: bool = False
+    top_k: int = -1
+    min_p: float = 0.0
+    repetition_penalty: float = 1.0
+    length_penalty: float = 1.0
+    early_stopping: bool = False
     stop_token_ids: Optional[List[int]] = Field(default_factory=list)
-    ignore_eos: Optional[bool] = False
-    min_tokens: Optional[int] = 0
-    skip_special_tokens: Optional[bool] = True
-    spaces_between_special_tokens: Optional[bool] = True
+    include_stop_str_in_output: bool = False
+    ignore_eos: bool = False
+    min_tokens: int = 0
+    skip_special_tokens: bool = True
+    spaces_between_special_tokens: bool = True
     truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None
     # doc: end-completion-sampling-params
 
     # doc: begin-completion-extra-params
-    include_stop_str_in_output: Optional[bool] = Field(
-        default=False,
+    add_special_tokens: bool = Field(
+        default=True,
         description=(
-            "Whether to include the stop string in the output. "
-            "This is only applied when the stop or stop_token_ids is set."),
+            "If true (the default), special tokens (e.g. BOS) will be added to "
+            "the prompt."),
     )
     response_format: Optional[ResponseFormat] = Field(
         default=None,
@@ -447,15 +445,15 @@ class CompletionRequest(OpenAIBaseModel):
             seed=self.seed,
             stop=self.stop,
             stop_token_ids=self.stop_token_ids,
-            logprobs=self.logprobs,
             ignore_eos=self.ignore_eos,
             max_tokens=self.max_tokens if not echo_without_generation else 1,
             min_tokens=self.min_tokens,
+            logprobs=self.logprobs,
             use_beam_search=self.use_beam_search,
             early_stopping=self.early_stopping,
             prompt_logprobs=self.logprobs if self.echo else None,
             skip_special_tokens=self.skip_special_tokens,
-            spaces_between_special_tokens=(self.spaces_between_special_tokens),
+            spaces_between_special_tokens=self.spaces_between_special_tokens,
             include_stop_str_in_output=self.include_stop_str_in_output,
             length_penalty=self.length_penalty,
             logits_processors=logits_processors,
@@ -489,11 +487,11 @@ class CompletionRequest(OpenAIBaseModel):
     def validate_stream_options(cls, data):
         if data.get("stream_options") and not data.get("stream"):
             raise ValueError(
-                "Stream options can only be defined when stream is True.")
+                "Stream options can only be defined when stream is true.")
         return data
 
 
-class EmbeddingRequest(BaseModel):
+class EmbeddingRequest(OpenAIBaseModel):
     # Ordered by official OpenAI API documentation
     # https://platform.openai.com/docs/api-reference/embeddings
     model: str
@@ -565,13 +563,13 @@ class CompletionStreamResponse(OpenAIBaseModel):
     usage: Optional[UsageInfo] = Field(default=None)
 
 
-class EmbeddingResponseData(BaseModel):
+class EmbeddingResponseData(OpenAIBaseModel):
     index: int
     object: str = "embedding"
     embedding: Union[List[float], str]
 
 
-class EmbeddingResponse(BaseModel):
+class EmbeddingResponse(OpenAIBaseModel):
     id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
     object: str = "list"
     created: int = Field(default_factory=lambda: int(time.time()))
@@ -670,8 +668,8 @@ class BatchRequestInput(OpenAIBaseModel):
     # /v1/chat/completions is supported.
     url: str
 
-    # The parameteters of the request.
-    body: Union[ChatCompletionRequest, ]
+    # The parameters of the request.
+    body: ChatCompletionRequest
 
 
 class BatchResponseData(OpenAIBaseModel):
@@ -703,12 +701,22 @@ class BatchRequestOutput(OpenAIBaseModel):
     error: Optional[Any]
 
 
-class TokenizeRequest(OpenAIBaseModel):
+class TokenizeCompletionRequest(OpenAIBaseModel):
+    model: str
+    prompt: str
+
+    add_special_tokens: bool = Field(default=True)
+
+
+class TokenizeChatRequest(OpenAIBaseModel):
+    model: str
+    messages: List[ChatCompletionMessageParam]
+
     add_generation_prompt: bool = Field(default=True)
     add_special_tokens: bool = Field(default=False)
-    prompt: Optional[str] = Field(default=None)
-    messages: Optional[List[ChatCompletionMessageParam]] = Field(default=None)
-    model: str
+
+
+TokenizeRequest = Union[TokenizeCompletionRequest, TokenizeChatRequest]
 
 
 class TokenizeResponse(OpenAIBaseModel):

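Illustration only (not part of the commit): with the TokenizeRequest split above, the /tokenize route of the OpenAI-compatible server accepts either a completion-style or a chat-style body. A hedged sketch using the requests library against a locally running server; the URL and model name are placeholders.

import requests

base_url = "http://localhost:8000"  # assumed local vLLM OpenAI-compatible server

# TokenizeCompletionRequest-style body
resp = requests.post(f"{base_url}/tokenize",
                     json={"model": "my-model", "prompt": "Hello, world!"})
print(resp.json())

# TokenizeChatRequest-style body
resp = requests.post(f"{base_url}/tokenize",
                     json={
                         "model": "my-model",
                         "messages": [{"role": "user", "content": "Hello!"}],
                         "add_generation_prompt": True,
                     })
print(resp.json())
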
@@ -6,6 +6,7 @@ import aiohttp
 
 from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
 from vllm.engine.async_llm_engine import AsyncLLMEngine
+from vllm.entrypoints.logger import RequestLogger
 from vllm.entrypoints.openai.protocol import (BatchRequestInput,
                                               BatchRequestOutput,
                                               BatchResponseData,
@@ -44,9 +45,17 @@ def parse_args():
                         type=nullable_str,
                         default="assistant",
                         help="The role name to return if "
-                        "`request.add_generation_prompt=true`.")
+                        "`request.add_generation_prompt=True`.")
 
     parser = AsyncEngineArgs.add_cli_args(parser)
 
+    parser.add_argument('--max-log-len',
+                        type=int,
+                        default=None,
+                        help='Max number of prompt characters or prompt '
+                        'ID numbers being printed in log.'
+                        '\n\nDefault: Unlimited')
+
     return parser.parse_args()
@@ -114,11 +123,20 @@ async def main(args):
     # When using single vLLM without engine_use_ray
     model_config = await engine.get_model_config()
 
+    if args.disable_log_requests:
+        request_logger = None
+    else:
+        request_logger = RequestLogger(max_log_len=args.max_log_len)
+
     openai_serving_chat = OpenAIServingChat(
         engine,
         model_config,
         served_model_names,
         args.response_role,
+        lora_modules=None,
+        prompt_adapters=None,
+        request_logger=request_logger,
+        chat_template=None,
     )
 
     # Submit all requests in the file to the engine "concurrently".

@@ -12,6 +12,7 @@ from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.entrypoints.chat_utils import (ConversationMessage,
                                          load_chat_template,
                                          parse_chat_message_content)
+from vllm.entrypoints.logger import RequestLogger
 from vllm.entrypoints.openai.protocol import (
     ChatCompletionLogProb, ChatCompletionLogProbs,
     ChatCompletionLogProbsContent, ChatCompletionNamedToolChoiceParam,
@@ -20,7 +21,8 @@ from vllm.entrypoints.openai.protocol import (
     ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse,
     FunctionCall, ToolCall, UsageInfo)
 from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
-                                                    OpenAIServing)
+                                                    OpenAIServing,
+                                                    PromptAdapterPath)
 from vllm.inputs import PromptInputs
 from vllm.logger import init_logger
 from vllm.model_executor.guided_decoding import (
@@ -37,17 +39,24 @@ logger = init_logger(__name__)
 
 class OpenAIServingChat(OpenAIServing):
 
-    def __init__(self,
-                 engine: AsyncLLMEngine,
-                 model_config: ModelConfig,
-                 served_model_names: List[str],
-                 response_role: str,
-                 lora_modules: Optional[List[LoRAModulePath]] = None,
-                 chat_template: Optional[str] = None):
+    def __init__(
+        self,
+        engine: AsyncLLMEngine,
+        model_config: ModelConfig,
+        served_model_names: List[str],
+        response_role: str,
+        *,
+        lora_modules: Optional[List[LoRAModulePath]],
+        prompt_adapters: Optional[List[PromptAdapterPath]],
+        request_logger: Optional[RequestLogger],
+        chat_template: Optional[str],
+    ):
         super().__init__(engine=engine,
                          model_config=model_config,
                          served_model_names=served_model_names,
-                         lora_modules=lora_modules)
+                         lora_modules=lora_modules,
+                         prompt_adapters=prompt_adapters,
+                         request_logger=request_logger)
 
         self.response_role = response_role
 
@@ -74,7 +83,12 @@ class OpenAIServingChat(OpenAIServing):
             return error_check_ret
 
         try:
-            _, lora_request = self._maybe_get_adapter(request)
+            (
+                lora_request,
+                prompt_adapter_request,
+            ) = self._maybe_get_adapters(request)
+
+            model_config = self.model_config
             tokenizer = await self.engine.get_tokenizer(lora_request)
 
             conversation: List[ConversationMessage] = []
@@ -82,7 +96,7 @@ class OpenAIServingChat(OpenAIServing):
 
             for msg in request.messages:
                 chat_parsed_result = parse_chat_message_content(
-                    msg, self.model_config, tokenizer)
+                    msg, model_config, tokenizer)
 
                 conversation.extend(chat_parsed_result.messages)
                 mm_futures.extend(chat_parsed_result.mm_futures)
@@ -116,14 +130,8 @@ class OpenAIServingChat(OpenAIServing):
             logger.error("Error in loading multi-modal data: %s", e)
             return self.create_error_response(str(e))
 
-        request_id = f"cmpl-{random_uuid()}"
+        request_id = f"chat-{random_uuid()}"
         try:
-            # Tokenize/detokenize depending on prompt format (string/token list)
-            prompt_ids, prompt_text = await self._validate_prompt_and_tokenize(
-                request,
-                tokenizer,
-                prompt=prompt,
-                add_special_tokens=request.add_special_tokens)
             sampling_params = request.to_sampling_params()
             decoding_config = await self.engine.get_decoding_config()
             guided_decoding_backend = request.guided_decoding_backend \
@@ -137,31 +145,47 @@ class OpenAIServingChat(OpenAIServing):
                 sampling_params.logits_processors = []
                 sampling_params.logits_processors.append(
                     guided_decode_logits_processor)
+
+            prompt_inputs = self._tokenize_prompt_input(
+                request,
+                tokenizer,
+                prompt,
+                truncate_prompt_tokens=sampling_params.truncate_prompt_tokens,
+                add_special_tokens=request.add_special_tokens,
+            )
+
+            self._log_inputs(request_id,
+                             prompt_inputs,
+                             params=sampling_params,
+                             lora_request=lora_request,
+                             prompt_adapter_request=prompt_adapter_request)
+
+            engine_inputs: PromptInputs = {
+                "prompt_token_ids": prompt_inputs["prompt_token_ids"],
+            }
+            if mm_data is not None:
+                engine_inputs["multi_modal_data"] = mm_data
+
+            is_tracing_enabled = await self.engine.is_tracing_enabled()
+            trace_headers = None
+            if is_tracing_enabled and raw_request:
+                trace_headers = extract_trace_headers(raw_request.headers)
+            if (not is_tracing_enabled and raw_request
+                    and contains_trace_headers(raw_request.headers)):
+                log_tracing_disabled_warning()
+
+            result_generator = self.engine.generate(
+                engine_inputs,
+                sampling_params,
+                request_id,
+                lora_request=lora_request,
+                trace_headers=trace_headers,
+                prompt_adapter_request=prompt_adapter_request,
+            )
         except ValueError as e:
             # TODO: Use a vllm-specific Validation Error
             return self.create_error_response(str(e))
 
-        inputs: PromptInputs = {
-            "prompt": prompt_text,
-            "prompt_token_ids": prompt_ids,
-        }
-        if mm_data:
-            inputs["multi_modal_data"] = mm_data
-
-        is_tracing_enabled = await self.engine.is_tracing_enabled()
-        trace_headers = None
-        if is_tracing_enabled and raw_request:
-            trace_headers = extract_trace_headers(raw_request.headers)
-        if not is_tracing_enabled and raw_request and contains_trace_headers(
-                raw_request.headers):
-            log_tracing_disabled_warning()
-
-        result_generator = self.engine.generate(
-            inputs,
-            sampling_params,
-            request_id,
-            lora_request,
-            trace_headers=trace_headers,
-        )
         # Streaming response
         if request.stream:
             return self.chat_completion_stream_generator(
@@ -195,10 +219,11 @@ class OpenAIServingChat(OpenAIServing):
         first_iteration = True
 
         # Send response for each token for each request.n (index)
-        assert request.n is not None
-        previous_texts = [""] * request.n
-        previous_num_tokens = [0] * request.n
-        finish_reason_sent = [False] * request.n
+        num_choices = 1 if request.n is None else request.n
+        previous_texts = [""] * num_choices
+        previous_num_tokens = [0] * num_choices
+        finish_reason_sent = [False] * num_choices
 
         try:
             async for res in result_generator:
                 # We need to do it here, because if there are exceptions in
@@ -208,7 +233,7 @@ class OpenAIServingChat(OpenAIServing):
                     # Send first response for each request.n (index) with
                     # the role
                     role = self.get_chat_request_role(request)
-                    for i in range(request.n):
+                    for i in range(num_choices):
                         choice_data = ChatCompletionResponseStreamChoice(
                             index=i,
                             delta=DeltaMessage(role=role),
@@ -236,19 +261,19 @@ class OpenAIServingChat(OpenAIServing):
                         last_msg_content = conversation[-1]["content"]
 
                     if last_msg_content:
-                        for i in range(request.n):
+                        for i in range(num_choices):
                             choice_data = (
                                 ChatCompletionResponseStreamChoice(
                                     index=i,
                                     delta=DeltaMessage(
                                         content=last_msg_content),
                                     logprobs=None,
                                     finish_reason=None))
                             chunk = ChatCompletionStreamResponse(
                                 id=request_id,
                                 object=chunk_object_type,
                                 created=created_time,
                                 choices=[choice_data],
                                 logprobs=None,
                                 model=model_name)
                             if (request.stream_options and
                                     request.stream_options.include_usage):

@@ -2,13 +2,14 @@ import time
 from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, List,
                     Optional)
 from typing import Sequence as GenericSequence
-from typing import Tuple
+from typing import Tuple, cast
 
 from fastapi import Request
 from transformers import PreTrainedTokenizer
 
 from vllm.config import ModelConfig
 from vllm.engine.async_llm_engine import AsyncLLMEngine
+from vllm.entrypoints.logger import RequestLogger
 # yapf conflicts with isort for this block
 # yapf: disable
 from vllm.entrypoints.openai.protocol import (CompletionLogProbs,
@@ -39,40 +40,24 @@ TypeCreateLogProbsFn = Callable[
     [TypeTokenIDs, TypeTopLogProbs, Optional[int], int], CompletionLogProbs]
 
 
-def parse_prompt_format(prompt) -> Tuple[bool, list]:
-    # get the prompt, openai supports the following
-    # "a string, array of strings, array of tokens, or array of token arrays."
-    prompt_is_tokens = False
-    prompts = [prompt]  # case 1: a string
-    if isinstance(prompt, list):
-        if len(prompt) == 0:
-            raise ValueError("please provide at least one prompt")
-        elif isinstance(prompt[0], str):
-            prompt_is_tokens = False
-            prompts = prompt  # case 2: array of strings
-        elif isinstance(prompt[0], int):
-            prompt_is_tokens = True
-            prompts = [prompt]  # case 3: array of tokens
-        elif isinstance(prompt[0], list) and isinstance(prompt[0][0], int):
-            prompt_is_tokens = True
-            prompts = prompt  # case 4: array of token arrays
-        else:
-            raise ValueError("prompt must be a string, array of strings, "
-                             "array of tokens, or array of token arrays")
-    return prompt_is_tokens, prompts
-
-
 class OpenAIServingCompletion(OpenAIServing):
 
-    def __init__(self, engine: AsyncLLMEngine, model_config: ModelConfig,
-                 served_model_names: List[str],
-                 lora_modules: Optional[List[LoRAModulePath]],
-                 prompt_adapters: Optional[List[PromptAdapterPath]]):
+    def __init__(
+        self,
+        engine: AsyncLLMEngine,
+        model_config: ModelConfig,
+        served_model_names: List[str],
+        *,
+        lora_modules: Optional[List[LoRAModulePath]],
+        prompt_adapters: Optional[List[PromptAdapterPath]],
+        request_logger: Optional[RequestLogger],
+    ):
         super().__init__(engine=engine,
                          model_config=model_config,
                          served_model_names=served_model_names,
                          lora_modules=lora_modules,
-                         prompt_adapters=prompt_adapters)
+                         prompt_adapters=prompt_adapters,
+                         request_logger=request_logger)
 
     async def create_completion(self, request: CompletionRequest,
                                 raw_request: Request):
@@ -101,12 +86,11 @@ class OpenAIServingCompletion(OpenAIServing):
         # Schedule the request and get the result generator.
         generators: List[AsyncIterator[RequestOutput]] = []
         try:
-            adapter_type, adapter_request = self._maybe_get_adapter(request)
-            lora_request, prompt_adapter_request = None, None
-            if adapter_type == 'LoRA':
-                lora_request, prompt_adapter_request = adapter_request, None
-            elif adapter_type == 'PromptAdapter':
-                lora_request, prompt_adapter_request = None, adapter_request
+            (
+                lora_request,
+                prompt_adapter_request,
+            ) = self._maybe_get_adapters(request)
 
             tokenizer = await self.engine.get_tokenizer(lora_request)
 
             sampling_params = request.to_sampling_params()
@@ -122,17 +106,25 @@ class OpenAIServingCompletion(OpenAIServing):
                 sampling_params.logits_processors = []
                 sampling_params.logits_processors.append(
                     guided_decode_logit_processor)
-            prompt_is_tokens, prompts = parse_prompt_format(request.prompt)
-
-            for i, prompt in enumerate(prompts):
-                prompt_arg = "prompt_ids" if prompt_is_tokens else "prompt"
-                prompt_formats = await self._validate_prompt_and_tokenize(
+
+            prompts = list(
+                self._tokenize_prompt_input_or_inputs(
                     request,
                     tokenizer,
+                    request.prompt,
                     truncate_prompt_tokens=sampling_params.
                     truncate_prompt_tokens,
-                    **{prompt_arg: prompt})
-                prompt_ids, prompt_text = prompt_formats
+                    add_special_tokens=request.add_special_tokens,
+                ))
+
+            for i, prompt_inputs in enumerate(prompts):
+                request_id_item = f"{request_id}-{i}"
+
+                self._log_inputs(request_id_item,
+                                 prompt_inputs,
+                                 params=sampling_params,
+                                 lora_request=lora_request,
+                                 prompt_adapter_request=prompt_adapter_request)
 
                 is_tracing_enabled = await self.engine.is_tracing_enabled()
                 trace_headers = None
@@ -143,12 +135,9 @@ class OpenAIServingCompletion(OpenAIServing):
                     log_tracing_disabled_warning()
 
                 generator = self.engine.generate(
-                    {
-                        "prompt": prompt_text,
-                        "prompt_token_ids": prompt_ids
-                    },
+                    {"prompt_token_ids": prompt_inputs["prompt_token_ids"]},
                     sampling_params,
-                    f"{request_id}-{i}",
+                    request_id_item,
                     lora_request=lora_request,
                     prompt_adapter_request=prompt_adapter_request,
                     trace_headers=trace_headers,
@@ -189,9 +178,27 @@ class OpenAIServingCompletion(OpenAIServing):
                     await self.engine.abort(f"{request_id}-{i}")
                     return self.create_error_response("Client disconnected")
                 final_res_batch[i] = res
+
+            for i, final_res in enumerate(final_res_batch):
+                assert final_res is not None
+
+                # The output should contain the input text
+                # We did not pass it into vLLM engine to avoid being redundant
+                # with the inputs token IDs
+                if final_res.prompt is None:
+                    final_res.prompt = prompts[i]["prompt"]
+
+            final_res_batch_checked = cast(List[RequestOutput],
+                                           final_res_batch)
+
             response = self.request_output_to_completion_response(
-                final_res_batch, request, request_id, created_time, model_name,
-                tokenizer)
+                final_res_batch_checked,
+                request,
+                request_id,
+                created_time,
+                model_name,
+                tokenizer,
+            )
         except ValueError as e:
             # TODO: Use a vllm-specific Validation Error
             return self.create_error_response(str(e))
@@ -220,10 +227,10 @@ class OpenAIServingCompletion(OpenAIServing):
         num_prompts: int,
         tokenizer: PreTrainedTokenizer,
     ) -> AsyncGenerator[str, None]:
-        assert request.n is not None
-        previous_texts = [""] * request.n * num_prompts
-        previous_num_tokens = [0] * request.n * num_prompts
-        has_echoed = [False] * request.n * num_prompts
+        num_choices = 1 if request.n is None else request.n
+        previous_texts = [""] * num_choices * num_prompts
+        previous_num_tokens = [0] * num_choices * num_prompts
+        has_echoed = [False] * num_choices * num_prompts
 
         try:
             async for prompt_idx, res in result_generator:
@@ -234,7 +241,7 @@ class OpenAIServingCompletion(OpenAIServing):
                     raise StopAsyncIteration()
 
                 for output in res.outputs:
-                    i = output.index + prompt_idx * request.n
+                    i = output.index + prompt_idx * num_choices
                     # TODO(simon): optimize the performance by avoiding full
                     # text O(n^2) sending.
@@ -343,8 +350,8 @@ class OpenAIServingCompletion(OpenAIServing):
         choices: List[CompletionResponseChoice] = []
         num_prompt_tokens = 0
         num_generated_tokens = 0
 
         for final_res in final_res_batch:
-            assert final_res is not None
             prompt_token_ids = final_res.prompt_token_ids
             prompt_logprobs = final_res.prompt_logprobs
             prompt_text = final_res.prompt

@@ -1,16 +1,16 @@
 import base64
 import time
-from typing import AsyncIterator, List, Optional, Tuple
+from typing import AsyncIterator, List, Optional, Tuple, cast
 
 import numpy as np
 from fastapi import Request
 
 from vllm.config import ModelConfig
 from vllm.engine.async_llm_engine import AsyncLLMEngine
+from vllm.entrypoints.logger import RequestLogger
 from vllm.entrypoints.openai.protocol import (EmbeddingRequest,
                                               EmbeddingResponse,
                                               EmbeddingResponseData, UsageInfo)
-from vllm.entrypoints.openai.serving_completion import parse_prompt_format
 from vllm.entrypoints.openai.serving_engine import OpenAIServing
 from vllm.logger import init_logger
 from vllm.outputs import EmbeddingRequestOutput
@@ -28,11 +28,11 @@ def request_output_to_embedding_response(
     data: List[EmbeddingResponseData] = []
     num_prompt_tokens = 0
     for idx, final_res in enumerate(final_res_batch):
-        assert final_res is not None
         prompt_token_ids = final_res.prompt_token_ids
         embedding = final_res.outputs.embedding
         if encoding_format == "base64":
-            embedding = base64.b64encode(np.array(embedding))
+            embedding_bytes = np.array(embedding).tobytes()
+            embedding = base64.b64encode(embedding_bytes).decode("utf-8")
         embedding_data = EmbeddingResponseData(index=idx, embedding=embedding)
         data.append(embedding_data)
@@ -54,12 +54,20 @@ def request_output_to_embedding_response(
 
 class OpenAIServingEmbedding(OpenAIServing):
 
-    def __init__(self, engine: AsyncLLMEngine, model_config: ModelConfig,
-                 served_model_names: List[str]):
+    def __init__(
+        self,
+        engine: AsyncLLMEngine,
+        model_config: ModelConfig,
+        served_model_names: List[str],
+        *,
+        request_logger: Optional[RequestLogger],
+    ):
         super().__init__(engine=engine,
                          model_config=model_config,
                          served_model_names=served_model_names,
-                         lora_modules=None)
+                         lora_modules=None,
+                         prompt_adapters=None,
+                         request_logger=request_logger)
         self._check_embedding_mode(model_config.embedding_mode)
 
     async def create_embedding(self, request: EmbeddingRequest,
@@ -80,29 +88,47 @@ class OpenAIServingEmbedding(OpenAIServing):
                 "dimensions is currently not supported")
 
         model_name = request.model
-        request_id = f"cmpl-{random_uuid()}"
+        request_id = f"embd-{random_uuid()}"
         created_time = int(time.monotonic())
 
         # Schedule the request and get the result generator.
-        generators = []
+        generators: List[AsyncIterator[EmbeddingRequestOutput]] = []
         try:
-            prompt_is_tokens, prompts = parse_prompt_format(request.input)
+            (
+                lora_request,
+                prompt_adapter_request,
+            ) = self._maybe_get_adapters(request)
+
+            tokenizer = await self.engine.get_tokenizer(lora_request)
+
             pooling_params = request.to_pooling_params()
 
-            tokenizer = await self.engine.get_tokenizer()
-            for i, prompt in enumerate(prompts):
-                prompt_arg = "prompt_ids" if prompt_is_tokens else "prompt"
-                prompt_formats = await self._validate_prompt_and_tokenize(
-                    request, tokenizer, **{prompt_arg: prompt})
-                prompt_ids, prompt_text = prompt_formats
+            prompts = list(
+                self._tokenize_prompt_input_or_inputs(
+                    request,
+                    tokenizer,
+                    request.input,
+                ))
+
+            for i, prompt_inputs in enumerate(prompts):
+                request_id_item = f"{request_id}-{i}"
+
+                self._log_inputs(request_id_item,
+                                 prompt_inputs,
+                                 params=pooling_params,
+                                 lora_request=lora_request,
+                                 prompt_adapter_request=prompt_adapter_request)
+
+                if prompt_adapter_request is not None:
+                    raise NotImplementedError(
+                        "Prompt adapter is not supported "
+                        "for embedding models")
 
                 generator = self.engine.encode(
-                    {
-                        "prompt": prompt_text,
-                        "prompt_token_ids": prompt_ids
-                    },
+                    {"prompt_token_ids": prompt_inputs["prompt_token_ids"]},
                     pooling_params,
-                    f"{request_id}-{i}",
+                    request_id_item,
+                    lora_request=lora_request,
                 )
 
                 generators.append(generator)
@@ -121,11 +147,17 @@ class OpenAIServingEmbedding(OpenAIServing):
             if await raw_request.is_disconnected():
                 # Abort the request if the client disconnects.
                 await self.engine.abort(f"{request_id}-{i}")
+                # TODO: Use a vllm-specific Validation Error
                 return self.create_error_response("Client disconnected")
             final_res_batch[i] = res
 
+        for final_res in final_res_batch:
+            assert final_res is not None
+
+        final_res_batch_checked = cast(List[EmbeddingRequestOutput],
+                                       final_res_batch)
+
         response = request_output_to_embedding_response(
-            final_res_batch, request_id, created_time, model_name,
+            final_res_batch_checked, request_id, created_time, model_name,
             encoding_format)
     except ValueError as e:
         # TODO: Use a vllm-specific Validation Error

@ -2,23 +2,33 @@ import json
|
|||
import pathlib
|
||||
from dataclasses import dataclass
|
||||
from http import HTTPStatus
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from typing import Iterable, Iterator, List, Optional, Tuple, TypedDict, Union
|
||||
|
||||
from pydantic import Field
|
||||
from transformers import PreTrainedTokenizer
|
||||
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||
from vllm.entrypoints.logger import RequestLogger
|
||||
# yapf conflicts with isort for this block
|
||||
# yapf: disable
|
||||
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
|
||||
CompletionRequest,
|
||||
DetokenizeRequest,
|
||||
EmbeddingRequest, ErrorResponse,
|
||||
ModelCard, ModelList,
|
||||
ModelPermission, TokenizeRequest)
|
||||
ModelPermission,
|
||||
TokenizeChatRequest,
|
||||
TokenizeCompletionRequest,
|
||||
TokenizeRequest)
|
||||
# yapf: enable
|
||||
from vllm.inputs import parse_and_batch_prompt
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.sequence import Logprob
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
@ -36,6 +46,17 @@ class LoRAModulePath:
|
|||
local_path: str
|
||||
|
||||
|
||||
AnyRequest = Union[ChatCompletionRequest, CompletionRequest, DetokenizeRequest,
|
||||
EmbeddingRequest, TokenizeRequest]
|
||||
|
||||
AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
|
||||
|
||||
|
||||
class TextTokensPrompt(TypedDict):
|
||||
prompt: str
|
||||
prompt_token_ids: List[int]
|
||||
|
||||
|
||||
class OpenAIServing:
|
||||
|
||||
def __init__(
|
||||
|
@ -43,8 +64,10 @@ class OpenAIServing:
|
|||
engine: AsyncLLMEngine,
|
||||
model_config: ModelConfig,
|
||||
served_model_names: List[str],
|
||||
*,
|
||||
lora_modules: Optional[List[LoRAModulePath]],
|
||||
prompt_adapters: Optional[List[PromptAdapterPath]] = None,
|
||||
prompt_adapters: Optional[List[PromptAdapterPath]],
|
||||
request_logger: Optional[RequestLogger],
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
|
@ -78,6 +101,8 @@ class OpenAIServing:
|
|||
prompt_adapter_local_path=prompt_adapter.local_path,
|
||||
prompt_adapter_num_virtual_tokens=num_virtual_tokens))
|
||||
|
||||
self.request_logger = request_logger
|
||||
|
||||
async def show_available_models(self) -> ModelList:
|
||||
"""Show available models. Right now we only have one model."""
|
||||
model_cards = [
|
||||
|
@ -126,9 +151,8 @@ class OpenAIServing:
|
|||
return json_str
|
||||
|
||||
async def _check_model(
|
||||
self, request: Union[ChatCompletionRequest, CompletionRequest,
|
||||
DetokenizeRequest, EmbeddingRequest,
|
||||
TokenizeRequest]
|
||||
self,
|
||||
request: AnyRequest,
|
||||
) -> Optional[ErrorResponse]:
|
||||
if request.model in self.served_model_names:
|
||||
return None
|
||||
|
@@ -144,64 +168,65 @@ class OpenAIServing:
                err_type="NotFoundError",
                status_code=HTTPStatus.NOT_FOUND)

    def _maybe_get_adapter(
        self, request: Union[CompletionRequest, ChatCompletionRequest,
                             EmbeddingRequest, TokenizeRequest,
                             DetokenizeRequest]
    ) -> Tuple[Optional[str], Optional[Union[LoRARequest,
                                             PromptAdapterRequest]]]:
    def _maybe_get_adapters(
        self, request: AnyRequest
    ) -> Union[Tuple[None, None], Tuple[LoRARequest, None], Tuple[
            None, PromptAdapterRequest]]:
        if request.model in self.served_model_names:
            return None, None
        for lora in self.lora_requests:
            if request.model == lora.lora_name:
                return 'LoRA', lora
                return lora, None
        for prompt_adapter in self.prompt_adapter_requests:
            if request.model == prompt_adapter.prompt_adapter_name:
                return 'PromptAdapter', prompt_adapter
                return None, prompt_adapter
        # if _check_model has been called earlier, this will be unreachable
        raise ValueError(f"The model `{request.model}` does not exist.")

    async def _validate_prompt_and_tokenize(
        self,
        request: Union[ChatCompletionRequest, CompletionRequest,
                       DetokenizeRequest, EmbeddingRequest,
                       TokenizeRequest],
        tokenizer: "PreTrainedTokenizer",
        prompt: Optional[str] = None,
        prompt_ids: Optional[List[int]] = None,
        truncate_prompt_tokens: Optional[Annotated[int,
                                                   Field(ge=1)]] = None,
        add_special_tokens: Optional[bool] = True
    ) -> Tuple[List[int], str]:
        if not (prompt or prompt_ids):
            raise ValueError("Either prompt or prompt_ids should be provided.")
        if prompt and prompt_ids:
            raise ValueError(
                "Only one of prompt or prompt_ids should be provided.")

        if prompt_ids is None:
            # When using OpenAIServingChat for chat completions, for
            # most models the special tokens (e.g., BOS) have already
            # been added by the chat template. Therefore, we do not
            # need to add them again.
            # Set add_special_tokens to False (by default) to avoid
            # adding the BOS tokens again.
            tokenizer_kwargs: Dict[str, Any] = {
                "add_special_tokens": add_special_tokens
            }
            if truncate_prompt_tokens is not None:
                tokenizer_kwargs.update({
                    "truncation": True,
                    "max_length": truncate_prompt_tokens,
                })
            input_ids = tokenizer(prompt, **tokenizer_kwargs).input_ids
        elif truncate_prompt_tokens is not None:
            input_ids = prompt_ids[-truncate_prompt_tokens:]
    def _normalize_prompt_text_to_input(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        prompt: str,
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]],
        add_special_tokens: bool,
    ) -> TextTokensPrompt:
        if truncate_prompt_tokens is None:
            encoded = tokenizer(prompt, add_special_tokens=add_special_tokens)
        else:
            input_ids = prompt_ids
            encoded = tokenizer(prompt,
                                add_special_tokens=add_special_tokens,
                                truncation=True,
                                max_length=truncate_prompt_tokens)

        input_text = prompt if prompt is not None else tokenizer.decode(
            input_ids)
        input_ids = encoded.input_ids

        input_text = prompt

        return self._validate_input(request, input_ids, input_text)

    def _normalize_prompt_tokens_to_input(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        prompt_ids: List[int],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]],
    ) -> TextTokensPrompt:
        if truncate_prompt_tokens is None:
            input_ids = prompt_ids
        else:
            input_ids = prompt_ids[-truncate_prompt_tokens:]

        input_text = tokenizer.decode(input_ids)

        return self._validate_input(request, input_ids, input_text)

    def _validate_input(
        self,
        request: AnyRequest,
        input_ids: List[int],
        input_text: str,
    ) -> TextTokensPrompt:
        token_num = len(input_ids)

        # Note: EmbeddingRequest doesn't have max_tokens
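Note on the helpers above: `_normalize_prompt_tokens_to_input` honors `truncate_prompt_tokens` by keeping only the trailing token IDs (the text path instead forwards `truncation=True`/`max_length` to the tokenizer), and both helpers delegate to `_validate_input`, which returns a `TextTokensPrompt`. A standalone sketch of the token-ID truncation rule (the helper name and token values are illustrative only, not part of this diff):

    from typing import List, Optional

    def truncate_tail(prompt_ids: List[int],
                      truncate_prompt_tokens: Optional[int]) -> List[int]:
        # Mirrors the slicing used above: keep only the last N token IDs.
        if truncate_prompt_tokens is None:
            return prompt_ids
        return prompt_ids[-truncate_prompt_tokens:]

    assert truncate_tail([1, 2, 3, 4, 5], 2) == [4, 5]
    assert truncate_tail([1, 2, 3], None) == [1, 2, 3]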
@@ -211,13 +236,16 @@ class OpenAIServing:
                    f"This model's maximum context length is "
                    f"{self.max_model_len} tokens. However, you requested "
                    f"{token_num} tokens in the input for embedding "
                    f"generation. Please reduce the length of the input.", )
            return input_ids, input_text
                    f"generation. Please reduce the length of the input.")
            return TextTokensPrompt(prompt=input_text,
                                    prompt_token_ids=input_ids)

        # Note: TokenizeRequest and DetokenizeRequest doesn't have max_tokens
        # and does not require model context length validation
        if isinstance(request, (TokenizeRequest, DetokenizeRequest)):
            return input_ids, input_text
        if isinstance(request, (TokenizeCompletionRequest, TokenizeChatRequest,
                                DetokenizeRequest)):
            return TextTokensPrompt(prompt=input_text,
                                    prompt_token_ids=input_ids)

        if request.max_tokens is None:
            if token_num >= self.max_model_len:
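The validation path now returns a `TextTokensPrompt` instead of an `(input_ids, input_text)` tuple, so callers read fields out of the TypedDict. A self-contained sketch of that shape (the prompt text and token IDs are invented for illustration):

    from typing import List, TypedDict

    class TextTokensPrompt(TypedDict):
        prompt: str
        prompt_token_ids: List[int]

    def consume(prompt_input: TextTokensPrompt) -> None:
        # Callers index the dict instead of unpacking a tuple.
        input_text = prompt_input["prompt"]
        input_ids = prompt_input["prompt_token_ids"]
        print(input_text, len(input_ids))

    consume({"prompt": "Hello, world", "prompt_token_ids": [9906, 11, 1917, 0]})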
@@ -225,7 +253,7 @@ class OpenAIServing:
                    f"This model's maximum context length is "
                    f"{self.max_model_len} tokens. However, you requested "
                    f"{token_num} tokens in the messages, "
                    f"Please reduce the length of the messages.", )
                    f"Please reduce the length of the messages.")
                request.max_tokens = self.max_model_len - token_num

        if token_num + request.max_tokens > self.max_model_len:
@@ -235,13 +263,132 @@ class OpenAIServing:
                f"{request.max_tokens + token_num} tokens "
                f"({token_num} in the messages, "
                f"{request.max_tokens} in the completion). "
                f"Please reduce the length of the messages or completion.", )
                f"Please reduce the length of the messages or completion.")

        return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids)

    def _tokenize_prompt_input(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        prompt_input: Union[str, List[int]],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
        add_special_tokens: bool = True,
    ) -> TextTokensPrompt:
        """
        A simpler implementation of :meth:`_tokenize_prompt_input_or_inputs`
        that assumes single input.
        """
        return next(
            self._tokenize_prompt_inputs(
                request,
                tokenizer,
                [prompt_input],
                truncate_prompt_tokens=truncate_prompt_tokens,
                add_special_tokens=add_special_tokens,
            ))

    def _tokenize_prompt_inputs(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        prompt_inputs: Iterable[Union[str, List[int]]],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
        add_special_tokens: bool = True,
    ) -> Iterator[TextTokensPrompt]:
        """
        A simpler implementation of :meth:`_tokenize_prompt_input_or_inputs`
        that assumes multiple inputs.
        """
        for text in prompt_inputs:
            if isinstance(text, str):
                yield self._normalize_prompt_text_to_input(
                    request,
                    tokenizer,
                    prompt=text,
                    truncate_prompt_tokens=truncate_prompt_tokens,
                    add_special_tokens=add_special_tokens,
                )
            else:
                yield self._normalize_prompt_tokens_to_input(
                    request,
                    tokenizer,
                    prompt_ids=text,
                    truncate_prompt_tokens=truncate_prompt_tokens,
                )

    def _tokenize_prompt_input_or_inputs(
        self,
        request: AnyRequest,
        tokenizer: AnyTokenizer,
        input_or_inputs: Union[str, List[str], List[int], List[List[int]]],
        truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None,
        add_special_tokens: bool = True,
    ) -> Iterator[TextTokensPrompt]:
        """
        Tokenize/detokenize depending on the input format.

        According to `OpenAI API <https://platform.openai.com/docs/api-reference/embeddings/create>`_
        , each input can be a string or array of tokens. Note that each request
        can pass one or more inputs.
        """
        for prompt_input in parse_and_batch_prompt(input_or_inputs):
            # Although our type checking is based on mypy,
            # VSCode Pyright extension should still work properly
            # "is True" is required for Pyright to perform type narrowing
            # See: https://github.com/microsoft/pyright/issues/7672
            if prompt_input["is_tokens"] is False:
                yield self._normalize_prompt_text_to_input(
                    request,
                    tokenizer,
                    prompt=prompt_input["content"],
                    truncate_prompt_tokens=truncate_prompt_tokens,
                    add_special_tokens=add_special_tokens,
                )
            else:
                yield self._normalize_prompt_tokens_to_input(
                    request,
                    tokenizer,
                    prompt_ids=prompt_input["content"],
                    truncate_prompt_tokens=truncate_prompt_tokens,
                )

    def _log_inputs(
        self,
        request_id: str,
        inputs: Union[str, List[int], TextTokensPrompt],
        params: Optional[Union[SamplingParams, PoolingParams]],
        lora_request: Optional[LoRARequest],
        prompt_adapter_request: Optional[PromptAdapterRequest],
    ) -> None:
        if self.request_logger is None:
            return

        if isinstance(inputs, str):
            prompt = inputs
            prompt_token_ids = None
        elif isinstance(inputs, list):
            prompt = None
            prompt_token_ids = inputs
        else:
        return input_ids, input_text
            prompt = inputs["prompt"]
            prompt_token_ids = inputs["prompt_token_ids"]

        self.request_logger.log_inputs(
            request_id,
            prompt,
            prompt_token_ids,
            params=params,
            lora_request=lora_request,
            prompt_adapter_request=prompt_adapter_request,
        )

    @staticmethod
    def _get_decoded_token(logprob: Logprob, token_id: int,
                           tokenizer: PreTrainedTokenizer) -> str:
    def _get_decoded_token(
        logprob: Logprob,
        token_id: int,
        tokenizer: AnyTokenizer,
    ) -> str:
        if logprob.decoded_token is not None:
            return logprob.decoded_token
        return tokenizer.decode(token_id)
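For reference, `_tokenize_prompt_input_or_inputs` accepts the same prompt shapes as the OpenAI completions/embeddings APIs, per its `input_or_inputs` annotation, and yields one `TextTokensPrompt` per input. A short illustration of the four accepted forms (all values are invented):

    # Any of these is a valid `input_or_inputs` argument:
    single_text = "Hello, world"               # str
    batched_text = ["Hello", "Goodbye"]        # List[str]
    single_tokens = [9906, 11, 1917]           # List[int]
    batched_tokens = [[9906], [11, 1917]]      # List[List[int]]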
@@ -1,83 +1,135 @@
from typing import List, Optional
from typing import List, Optional, Union

from vllm.config import ModelConfig
from vllm.engine.async_llm_engine import AsyncLLMEngine
# yapf conflicts with isort for this block
# yapf: disable
from vllm.entrypoints.chat_utils import (ConversationMessage,
                                         load_chat_template,
                                         parse_chat_message_content)
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (DetokenizeRequest,
                                              DetokenizeResponse,
                                              ErrorResponse,
                                              TokenizeChatRequest,
                                              TokenizeRequest,
                                              TokenizeResponse)
# yapf: enable
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
                                                    OpenAIServing)
from vllm.utils import random_uuid


class OpenAIServingTokenization(OpenAIServing):

    def __init__(self,
                 engine: AsyncLLMEngine,
                 model_config: ModelConfig,
                 served_model_names: List[str],
                 lora_modules: Optional[List[LoRAModulePath]] = None,
                 chat_template: Optional[str] = None):
    def __init__(
        self,
        engine: AsyncLLMEngine,
        model_config: ModelConfig,
        served_model_names: List[str],
        *,
        lora_modules: Optional[List[LoRAModulePath]],
        request_logger: Optional[RequestLogger],
        chat_template: Optional[str],
    ):
        super().__init__(engine=engine,
                         model_config=model_config,
                         served_model_names=served_model_names,
                         lora_modules=lora_modules)
                         lora_modules=lora_modules,
                         prompt_adapters=None,
                         request_logger=request_logger)

        # If this is None we use the tokenizer's default chat template
        self.chat_template = load_chat_template(chat_template)

    async def create_tokenize(self,
                              request: TokenizeRequest) -> TokenizeResponse:
    async def create_tokenize(
        self,
        request: TokenizeRequest,
    ) -> Union[TokenizeResponse, ErrorResponse]:
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        if not (request.prompt or request.messages):
            return self.create_error_response(
                "Either `prompt` or `messages` should be provided.")
        request_id = f"tokn-{random_uuid()}"

        if (request.prompt and request.messages):
            return self.create_error_response(
                "Only one of `prompt` or `messages` should be provided.")
        (
            lora_request,
            prompt_adapter_request,
        ) = self._maybe_get_adapters(request)

        _, lora_request = self._maybe_get_adapter(request)
        tokenizer = await self.engine.get_tokenizer(lora_request)
        if request.messages:

        if isinstance(request, TokenizeChatRequest):
            model_config = self.model_config

            conversation: List[ConversationMessage] = []

            for message in request.messages:
                result = parse_chat_message_content(message, self.model_config,
                result = parse_chat_message_content(message, model_config,
                                                    tokenizer)
                conversation.extend(result.messages)

            request.prompt = tokenizer.apply_chat_template(
            prompt = tokenizer.apply_chat_template(
                add_generation_prompt=request.add_generation_prompt,
                conversation=conversation,
                tokenize=False,
                chat_template=self.chat_template)
            assert isinstance(prompt, str)
        else:
            prompt = request.prompt

        (input_ids, input_text) = await self._validate_prompt_and_tokenize(
        self._log_inputs(request_id,
                         prompt,
                         params=None,
                         lora_request=lora_request,
                         prompt_adapter_request=prompt_adapter_request)

        # Silently ignore prompt adapter since it does not affect tokenization

        prompt_input = self._tokenize_prompt_input(
            request,
            tokenizer,
            prompt=request.prompt,
            add_special_tokens=request.add_special_tokens)
            prompt,
            add_special_tokens=request.add_special_tokens,
        )
        input_ids = prompt_input["prompt_token_ids"]

        return TokenizeResponse(tokens=input_ids,
                                count=len(input_ids),
                                max_model_len=self.max_model_len)

    async def create_detokenize(
            self, request: DetokenizeRequest) -> DetokenizeResponse:
        self,
        request: DetokenizeRequest,
    ) -> Union[DetokenizeResponse, ErrorResponse]:
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        _, lora_request = self._maybe_get_adapter(request)
        request_id = f"tokn-{random_uuid()}"

        (
            lora_request,
            prompt_adapter_request,
        ) = self._maybe_get_adapters(request)

        tokenizer = await self.engine.get_tokenizer(lora_request)
        (input_ids, input_text) = await self._validate_prompt_and_tokenize(
            request, tokenizer, prompt_ids=request.tokens)

        self._log_inputs(request_id,
                         request.tokens,
                         params=None,
                         lora_request=lora_request,
                         prompt_adapter_request=prompt_adapter_request)

        if prompt_adapter_request is not None:
            raise NotImplementedError("Prompt adapter is not supported "
                                      "for tokenization")

        prompt_input = self._tokenize_prompt_input(
            request,
            tokenizer,
            request.tokens,
        )
        input_text = prompt_input["prompt"]

        return DetokenizeResponse(prompt=input_text)
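The tokenize path above now always goes through `_tokenize_prompt_input` and builds the response from the returned `TextTokensPrompt`. A hedged sketch of the resulting payload shape (field values are invented; only the field names come from this diff):

    # TokenizeResponse is built as:
    #   tokens        -> prompt_input["prompt_token_ids"]
    #   count         -> len(tokens)
    #   max_model_len -> the serving layer's context length
    example_tokenize_response = {
        "tokens": [9906, 11, 1917, 0],
        "count": 4,
        "max_model_len": 4096,
    }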
@@ -1,6 +1,5 @@
from .data import (LLMInputs, ParsedText, ParsedTokens, PromptInputs,
                   PromptStrictInputs, TextPrompt, TextTokensPrompt,
                   TokensPrompt, parse_and_batch_prompt)
                   TextPrompt, TokensPrompt, parse_and_batch_prompt)
from .registry import InputContext, InputRegistry

INPUT_REGISTRY = InputRegistry()

@@ -14,6 +13,6 @@ See also:

__all__ = [
    "ParsedText", "ParsedTokens", "parse_and_batch_prompt", "TextPrompt",
    "TokensPrompt", "TextTokensPrompt", "PromptStrictInputs", "PromptInputs",
    "LLMInputs", "INPUT_REGISTRY", "InputContext", "InputRegistry"
    "TokensPrompt", "PromptInputs", "LLMInputs", "INPUT_REGISTRY",
    "InputContext", "InputRegistry"
]
@@ -92,25 +92,7 @@ class TokensPrompt(TypedDict):
    """


class TextTokensPrompt(TypedDict):
    """It is assumed that :attr:`prompt` is consistent with
    :attr:`prompt_token_ids`. This is currently used in
    :class:`AsyncLLMEngine` for logging both the text and token IDs."""

    prompt: str
    """The prompt text."""

    prompt_token_ids: List[int]
    """The token IDs of the prompt."""

    multi_modal_data: NotRequired["MultiModalDataDict"]
    """
    Optional multi-modal data to pass to the model,
    if the model supports it.
    """


PromptStrictInputs = Union[str, TextPrompt, TokensPrompt]
PromptInputs = Union[str, TextPrompt, TokensPrompt]
"""
The inputs to the LLM, which can take one of the following forms:

@@ -118,10 +100,6 @@ The inputs to the LLM, which can take one of the following forms:
- A tokenized prompt (:class:`TokensPrompt`)
"""

PromptInputs = Union[str, TextPrompt, TokensPrompt, TextTokensPrompt]
"""Same as :const:`PromptStrictInputs` but additionally accepts
:class:`TextTokensPrompt`."""


class LLMInputs(TypedDict):
    """
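After this change, `PromptInputs` is the single public union of prompt formats. A minimal sketch of the three accepted forms (the text and token IDs are illustrative):

    from vllm.inputs import PromptInputs, TextPrompt, TokensPrompt

    # A plain string prompt:
    as_str: PromptInputs = "Hello, world"

    # A TextPrompt dict:
    as_text: PromptInputs = TextPrompt(prompt="Hello, world")

    # A TokensPrompt dict (token IDs are made up):
    as_tokens: PromptInputs = TokensPrompt(prompt_token_ids=[9906, 11, 1917])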
@@ -5,7 +5,8 @@ import math
from abc import ABC, abstractmethod
from collections import defaultdict
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple, Union
from typing import (TYPE_CHECKING, Dict, List, Mapping, Optional, Set, Tuple,
                    Union)

import torch

@@ -438,7 +439,7 @@ class SequenceGroup:
        embeddings: Optional[List[float]] = None,
        pooling_params: Optional[PoolingParams] = None,
        encoder_seq: Optional[Sequence] = None,
        trace_headers: Optional[Dict[str, str]] = None,
        trace_headers: Optional[Mapping[str, str]] = None,
        prompt_adapter_request: Optional[PromptAdapterRequest] = None,
    ) -> None:
        self.request_id = request_id
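Typing `trace_headers` as `Mapping[str, str]` lets callers pass any read-only mapping (for example an immutable view of HTTP headers) rather than requiring a concrete `dict`. A small, self-contained illustration using only the standard library (the function name and header value are made up):

    from types import MappingProxyType
    from typing import Mapping, Optional

    def accepts_trace_headers(trace_headers: Optional[Mapping[str, str]]) -> None:
        # A Mapping-typed parameter admits dicts and read-only views alike.
        if trace_headers:
            for name, value in trace_headers.items():
                print(f"{name}: {value}")

    accepts_trace_headers({"traceparent": "00-abc-def-01"})
    accepts_trace_headers(MappingProxyType({"traceparent": "00-abc-def-01"}))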