# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
import os
import shutil
import tempfile

import torch
from huggingface_hub import snapshot_download
from safetensors import safe_open

from vllm import LLM, SamplingParams


def patch_eagle_draft_with_lm_head(target_model_id: str,
                                   draft_model_id: str) -> str:
    # In NxDI, the draft model checkpoint must include the lm_head weights
    # from the target model. For more details see
    # https://awsdocs-neuron.readthedocs-hosted.com
    # /en/latest/libraries/nxd-inference/developer_guides/feature-guide.html
    # #eagle-checkpoint-compatibility
    final_draft_dir = "/tmp/patched_eagle_draft"

    with tempfile.TemporaryDirectory() as tmp_dir:
        target_dir = snapshot_download(repo_id=target_model_id,
                                       local_dir=os.path.join(
                                           tmp_dir, "target"))
        draft_dir = snapshot_download(repo_id=draft_model_id,
                                      local_dir=os.path.join(tmp_dir, "draft"))

        # Locate the shard that holds the target model's lm_head weights via
        # the safetensors index, then read the tensor from that shard.
        lm_head_key = "lm_head.weight"
        index_path = os.path.join(target_dir, "model.safetensors.index.json")
        with open(index_path) as f:
            index = json.load(f)
        shard_name = index["weight_map"][lm_head_key]
        target_safetensor_path = os.path.join(target_dir, shard_name)

        with safe_open(target_safetensor_path, framework="pt") as f:
            target_lm_head = f.get_tensor(lm_head_key)

        # Inject the target lm_head into the draft checkpoint (cast to fp16
        # to match the draft weights) and write it back in place.
        draft_path = os.path.join(draft_dir, "pytorch_model.bin")
        draft_state_dict = torch.load(draft_path, map_location="cpu")
        draft_state_dict[lm_head_key] = target_lm_head.to(torch.float16)
        torch.save(draft_state_dict, draft_path)

        # Copy the patched draft outside the temporary directory so it
        # survives after the context manager cleans up.
        shutil.copytree(draft_dir, final_draft_dir, dirs_exist_ok=True)

    return final_draft_dir


def test_eagle():
    patched_draft_path = patch_eagle_draft_with_lm_head(
        target_model_id="meta-llama/Llama-2-7b-hf",
        draft_model_id="yuhuili/EAGLE-llama2-chat-7B")

    llm = LLM(
        model="meta-llama/Llama-2-7b-hf",
        speculative_config={
            "model": patched_draft_path,
            "num_speculative_tokens": 5,
            "max_model_len": 128
        },
        max_num_seqs=1,
        max_model_len=128,
        tensor_parallel_size=2,
        override_neuron_config={
            "enable_eagle_speculation": True,
            "enable_fused_speculation": True,
            "fused_qkv": True
        },
    )

    prompts = [
        "The president of the United States is",
    ]

    # Greedy decoding (top_k=1) makes the output deterministic so it can be
    # compared against a fixed reference string.
    outputs = llm.generate(prompts, SamplingParams(top_k=1))
    expected_output = " the head of state and head of government of " \
                      "the United States. The president direct"

    for output in outputs:
        generated_text = output.outputs[0].text
        print(f"Prompt: {output.prompt!r}, Generated text: {generated_text!r}")
        assert generated_text == expected_output

    print("Neuron Eagle speculation test passed.")
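

# A minimal sanity-check sketch, not part of the original test: it verifies
# that the patched draft checkpoint actually contains the injected lm_head
# weights in fp16. The filename and dtype assumptions mirror
# patch_eagle_draft_with_lm_head above; the helper name is hypothetical.
def check_patched_draft(draft_dir: str) -> None:
    state_dict = torch.load(os.path.join(draft_dir, "pytorch_model.bin"),
                            map_location="cpu")
    assert "lm_head.weight" in state_dict, "lm_head was not injected"
    assert state_dict["lm_head.weight"].dtype == torch.float16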
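

# Convenience guard (an addition, assuming direct invocation is desired):
# lets the test run standalone with `python test_eagle.py` instead of pytest.
if __name__ == "__main__":
    test_eagle()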