# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This example demonstrates text generation with an LLM, comparing
performance with and without speculative decoding.

Note that this example does not yet support the `v1` engine, so run it
with `VLLM_USE_V1=0`:

    VLLM_USE_V1=0 python examples/offline_inference/mlpspeculator.py
"""

import gc
import time

from vllm import LLM, SamplingParams


def time_generation(
    llm: LLM, prompts: list[str], sampling_params: SamplingParams, title: str
):
    # Generate texts from the prompts. The output is a list of RequestOutput
    # objects that contain the prompt, generated text, and other information.
    # Warm up first so one-time startup costs do not skew the timing.
    llm.generate(prompts, sampling_params)
    llm.generate(prompts, sampling_params)
    start = time.time()
    outputs = llm.generate(prompts, sampling_params)
    end = time.time()
    print("-" * 50)
    print(title)
    # Report the elapsed time normalized by the number of generated tokens.
    print(
        "time per output token: ",
        (end - start) / sum(len(o.outputs[0].token_ids) for o in outputs),
    )
    # Print the outputs.
    for output in outputs:
        generated_text = output.outputs[0].text
        print(f"text: {generated_text!r}")
    print("-" * 50)


def main():
    template = (
        "Below is an instruction that describes a task. Write a response "
        "that appropriately completes the request.\n\n### Instruction:\n{}"
        "\n\n### Response:\n"
    )

    # Sample prompts.
    prompts = [
        "Write about the president of the United States.",
    ]
    prompts = [template.format(prompt) for prompt in prompts]

    # Create a sampling params object.
    sampling_params = SamplingParams(temperature=0.0, max_tokens=200)

    # Create an LLM without speculative decoding.
    llm = LLM(model="meta-llama/Llama-2-13b-chat-hf")

    time_generation(llm, prompts, sampling_params, "Without speculation")

    del llm
    gc.collect()

    # Create an LLM with speculative decoding, using an MLPSpeculator
    # draft model to propose tokens for the base model to verify.
    llm = LLM(
        model="meta-llama/Llama-2-13b-chat-hf",
        speculative_config={
            "model": "ibm-ai-platform/llama-13b-accelerator",
        },
    )

    time_generation(llm, prompts, sampling_params, "With speculation")


if __name__ == "__main__":
    main()
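
# Rough sketch of an alternative launch (an assumption, not shown in the
# docstring above): instead of prefixing the command with VLLM_USE_V1=0,
# the environment variable can be set from Python before vLLM is imported,
# e.g.:
#
#     import os
#     os.environ.setdefault("VLLM_USE_V1", "0")  # set before `from vllm import ...`
#     from vllm import LLM, SamplingParams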