mirror of https://github.com/vllm-project/vllm.git
[Bugfix] Fix include prompt in stream response when echo=true (#15233)
Signed-off-by: Yuan Fang <yuanfang@alauda.io>
This commit is contained in:
parent
6d42ce8315
commit
e28533a16f
|
@ -779,3 +779,57 @@ async def test_guided_decoding_type_error(client: openai.AsyncOpenAI,
|
|||
prompt="Give an example string that fits this regex",
|
||||
extra_body=dict(guided_regex=sample_regex,
|
||||
guided_json=sample_json_schema))
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"model_name,stream,echo",
|
||||
[
|
||||
(MODEL_NAME, False, False),
|
||||
(MODEL_NAME, False, True),
|
||||
(MODEL_NAME, True, False),
|
||||
(MODEL_NAME, True, True) # should not raise BadRequestError error
|
||||
],
|
||||
)
|
||||
async def test_echo_stream_completion(client: openai.AsyncOpenAI,
|
||||
model_name: str, stream: bool,
|
||||
echo: bool):
|
||||
saying: str = "Hello, my name is"
|
||||
result = await client.completions.create(model=model_name,
|
||||
prompt=saying,
|
||||
max_tokens=10,
|
||||
temperature=0.0,
|
||||
echo=echo,
|
||||
stream=stream)
|
||||
|
||||
stop_reason = "length"
|
||||
|
||||
if not stream:
|
||||
completion = result
|
||||
assert completion.id is not None
|
||||
assert completion.choices is not None and len(completion.choices) == 1
|
||||
|
||||
choice = completion.choices[0]
|
||||
assert len(choice.text) >= 5
|
||||
assert choice.finish_reason == stop_reason
|
||||
|
||||
if echo:
|
||||
assert choice.text is not None and saying in choice.text
|
||||
else:
|
||||
assert choice.text is not None and saying not in choice.text
|
||||
|
||||
else:
|
||||
chunks: list[str] = []
|
||||
final_finish_reason = None
|
||||
async for chunk in result:
|
||||
if chunk.choices and chunk.choices[0].text:
|
||||
chunks.append(chunk.choices[0].text)
|
||||
if chunk.choices and chunk.choices[0].finish_reason:
|
||||
final_finish_reason = chunk.choices[0].finish_reason
|
||||
|
||||
assert final_finish_reason == stop_reason
|
||||
content = "".join(chunks)
|
||||
if echo:
|
||||
assert content is not None and saying in content
|
||||
else:
|
||||
assert content is not None and saying not in content
|
||||
|
|
|
@ -25,10 +25,13 @@ from vllm.entrypoints.openai.protocol import (CompletionLogProbs,
|
|||
ErrorResponse,
|
||||
RequestResponseMetadata,
|
||||
UsageInfo)
|
||||
# yapf: enable
|
||||
from vllm.entrypoints.openai.serving_engine import (
|
||||
EmbedsPrompt as ServingEngineEmbedsPrompt)
|
||||
from vllm.entrypoints.openai.serving_engine import (OpenAIServing,
|
||||
TextTokensPrompt,
|
||||
clamp_prompt_logprobs,
|
||||
is_text_tokens_prompt)
|
||||
# yapf: enable
|
||||
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
|
||||
from vllm.inputs.data import (EmbedsPrompt, TokensPrompt, is_embeds_prompt,
|
||||
is_tokens_prompt)
|
||||
|
@ -223,6 +226,7 @@ class OpenAIServingCompletion(OpenAIServing):
|
|||
if stream:
|
||||
return self.completion_stream_generator(
|
||||
request,
|
||||
request_prompts,
|
||||
result_generator,
|
||||
request_id,
|
||||
created_time,
|
||||
|
@ -285,6 +289,8 @@ class OpenAIServingCompletion(OpenAIServing):
|
|||
async def completion_stream_generator(
|
||||
self,
|
||||
request: CompletionRequest,
|
||||
request_prompts: list[Union[TextTokensPrompt,
|
||||
ServingEngineEmbedsPrompt]],
|
||||
result_generator: AsyncIterator[tuple[int, RequestOutput]],
|
||||
request_id: str,
|
||||
created_time: int,
|
||||
|
@ -313,7 +319,15 @@ class OpenAIServingCompletion(OpenAIServing):
|
|||
async for prompt_idx, res in result_generator:
|
||||
prompt_token_ids = res.prompt_token_ids
|
||||
prompt_logprobs = res.prompt_logprobs
|
||||
prompt_text = res.prompt
|
||||
|
||||
if res.prompt is not None:
|
||||
prompt_text = res.prompt
|
||||
else:
|
||||
request_prompt = request_prompts[prompt_idx]
|
||||
if is_text_tokens_prompt(request_prompt):
|
||||
prompt_text = request_prompt["prompt"]
|
||||
else:
|
||||
prompt_text = None
|
||||
|
||||
# Prompt details are excluded from later streamed outputs
|
||||
if prompt_token_ids is not None:
|
||||
|
@ -336,14 +350,13 @@ class OpenAIServingCompletion(OpenAIServing):
|
|||
delta_token_ids = prompt_token_ids
|
||||
out_logprobs = prompt_logprobs
|
||||
else:
|
||||
assert prompt_logprobs is not None
|
||||
# echo the prompt and first token
|
||||
delta_text = prompt_text + output.text
|
||||
delta_token_ids = [
|
||||
*prompt_token_ids, *output.token_ids
|
||||
]
|
||||
out_logprobs = [
|
||||
*prompt_logprobs,
|
||||
*(prompt_logprobs or []),
|
||||
*(output.logprobs or []),
|
||||
]
|
||||
has_echoed[i] = True
|
||||
|
|
Loading…
Reference in New Issue