diff --git a/tests/entrypoints/openai/test_completion.py b/tests/entrypoints/openai/test_completion.py
index 7e54143f6e..7933ca5cd6 100644
--- a/tests/entrypoints/openai/test_completion.py
+++ b/tests/entrypoints/openai/test_completion.py
@@ -779,3 +779,57 @@ async def test_guided_decoding_type_error(client: openai.AsyncOpenAI,
             prompt="Give an example string that fits this regex",
             extra_body=dict(guided_regex=sample_regex,
                             guided_json=sample_json_schema))
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "model_name,stream,echo",
+    [
+        (MODEL_NAME, False, False),
+        (MODEL_NAME, False, True),
+        (MODEL_NAME, True, False),
+        (MODEL_NAME, True, True)  # should not raise BadRequestError error
+    ],
+)
+async def test_echo_stream_completion(client: openai.AsyncOpenAI,
+                                      model_name: str, stream: bool,
+                                      echo: bool):
+    saying: str = "Hello, my name is"
+    result = await client.completions.create(model=model_name,
+                                             prompt=saying,
+                                             max_tokens=10,
+                                             temperature=0.0,
+                                             echo=echo,
+                                             stream=stream)
+
+    stop_reason = "length"
+
+    if not stream:
+        completion = result
+        assert completion.id is not None
+        assert completion.choices is not None and len(completion.choices) == 1
+
+        choice = completion.choices[0]
+        assert len(choice.text) >= 5
+        assert choice.finish_reason == stop_reason
+
+        if echo:
+            assert choice.text is not None and saying in choice.text
+        else:
+            assert choice.text is not None and saying not in choice.text
+
+    else:
+        chunks: list[str] = []
+        final_finish_reason = None
+        async for chunk in result:
+            if chunk.choices and chunk.choices[0].text:
+                chunks.append(chunk.choices[0].text)
+            if chunk.choices and chunk.choices[0].finish_reason:
+                final_finish_reason = chunk.choices[0].finish_reason
+
+        assert final_finish_reason == stop_reason
+        content = "".join(chunks)
+        if echo:
+            assert content is not None and saying in content
+        else:
+            assert content is not None and saying not in content
diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py
index a19fde8d70..8171b491aa 100644
--- a/vllm/entrypoints/openai/serving_completion.py
+++ b/vllm/entrypoints/openai/serving_completion.py
@@ -25,10 +25,13 @@ from vllm.entrypoints.openai.protocol import (CompletionLogProbs,
                                               ErrorResponse,
                                               RequestResponseMetadata,
                                               UsageInfo)
-# yapf: enable
+from vllm.entrypoints.openai.serving_engine import (
+    EmbedsPrompt as ServingEngineEmbedsPrompt)
 from vllm.entrypoints.openai.serving_engine import (OpenAIServing,
+                                                    TextTokensPrompt,
                                                     clamp_prompt_logprobs,
                                                     is_text_tokens_prompt)
+# yapf: enable
 from vllm.entrypoints.openai.serving_models import OpenAIServingModels
 from vllm.inputs.data import (EmbedsPrompt, TokensPrompt, is_embeds_prompt,
                               is_tokens_prompt)
@@ -223,6 +226,7 @@ class OpenAIServingCompletion(OpenAIServing):
             if stream:
                 return self.completion_stream_generator(
                     request,
+                    request_prompts,
                     result_generator,
                     request_id,
                     created_time,
@@ -285,6 +289,8 @@ class OpenAIServingCompletion(OpenAIServing):
     async def completion_stream_generator(
         self,
         request: CompletionRequest,
+        request_prompts: list[Union[TextTokensPrompt,
+                                    ServingEngineEmbedsPrompt]],
         result_generator: AsyncIterator[tuple[int, RequestOutput]],
         request_id: str,
         created_time: int,
@@ -313,7 +319,15 @@ class OpenAIServingCompletion(OpenAIServing):
             async for prompt_idx, res in result_generator:
                 prompt_token_ids = res.prompt_token_ids
                 prompt_logprobs = res.prompt_logprobs
-                prompt_text = res.prompt
+
+                if res.prompt is not None:
+                    prompt_text = res.prompt
+                else:
+                    request_prompt = request_prompts[prompt_idx]
+                    if is_text_tokens_prompt(request_prompt):
+                        prompt_text = request_prompt["prompt"]
+                    else:
+                        prompt_text = None
 
                 # Prompt details are excluded from later streamed outputs
                 if prompt_token_ids is not None:
@@ -336,14 +350,13 @@ class OpenAIServingCompletion(OpenAIServing):
                             delta_token_ids = prompt_token_ids
                             out_logprobs = prompt_logprobs
                         else:
-                            assert prompt_logprobs is not None
                             # echo the prompt and first token
                             delta_text = prompt_text + output.text
                             delta_token_ids = [
                                 *prompt_token_ids, *output.token_ids
                             ]
                             out_logprobs = [
-                                *prompt_logprobs,
+                                *(prompt_logprobs or []),
                                 *(output.logprobs or []),
                             ]
                         has_echoed[i] = True
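For reference, the streaming echo path exercised by the new test can also be driven manually against a running vLLM OpenAI-compatible server. The sketch below is illustrative only and not part of the diff: the base URL, API key, and model name are placeholder assumptions for a local deployment.

```python
import asyncio

import openai

# Placeholder assumptions: adjust to match the locally running vLLM server.
BASE_URL = "http://localhost:8000/v1"
MODEL = "<served-model-name>"


async def main() -> None:
    client = openai.AsyncOpenAI(base_url=BASE_URL, api_key="EMPTY")
    prompt = "Hello, my name is"

    # echo=True together with stream=True is the combination the change above
    # exercises; per the test, it should not raise a BadRequestError.
    stream = await client.completions.create(model=MODEL,
                                             prompt=prompt,
                                             max_tokens=10,
                                             temperature=0.0,
                                             echo=True,
                                             stream=True)

    pieces: list[str] = []
    async for chunk in stream:
        if chunk.choices and chunk.choices[0].text:
            pieces.append(chunk.choices[0].text)

    text = "".join(pieces)
    # With echo enabled, the prompt should appear in the streamed output.
    assert prompt in text
    print(text)


if __name__ == "__main__":
    asyncio.run(main())
```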