# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import json
import os
import shutil
import tempfile

import torch
from huggingface_hub import snapshot_download
from safetensors import safe_open

from vllm import LLM, SamplingParams


def patch_eagle_draft_with_lm_head(target_model_id: str,
                                   draft_model_id: str) -> str:
    """Copy the target model's lm_head weights into the EAGLE draft
    checkpoint and return the path to the patched draft directory.

    In NxDI, the draft model checkpoint must include the lm_head weights
    from the target model. For more details, see
    https://awsdocs-neuron.readthedocs-hosted.com/en/latest/libraries/nxd-inference/developer_guides/feature-guide.html#eagle-checkpoint-compatibility
    """
    final_draft_dir = "/tmp/patched_eagle_draft"

    with tempfile.TemporaryDirectory() as tmp_dir:
        target_dir = snapshot_download(repo_id=target_model_id,
                                       local_dir=os.path.join(
                                           tmp_dir, "target"))
        draft_dir = snapshot_download(repo_id=draft_model_id,
                                      local_dir=os.path.join(tmp_dir, "draft"))

        # Locate the safetensors shard that holds lm_head.weight in the
        # (possibly sharded) target checkpoint.
        lm_head_key = "lm_head.weight"
        index_path = os.path.join(target_dir, "model.safetensors.index.json")
        with open(index_path) as f:
            index = json.load(f)
        shard_name = index["weight_map"][lm_head_key]
        target_safetensor_path = os.path.join(target_dir, shard_name)

        with safe_open(target_safetensor_path, framework="pt") as f:
            target_lm_head = f.get_tensor(lm_head_key)

        # Inject the target model's lm_head into the draft state dict,
        # cast to float16, and write the checkpoint back in place.
        draft_path = os.path.join(draft_dir, "pytorch_model.bin")
        draft_state_dict = torch.load(draft_path, map_location="cpu")
        draft_state_dict[lm_head_key] = target_lm_head.to(torch.float16)
        torch.save(draft_state_dict, draft_path)

        # Persist the patched draft outside the temporary directory.
        shutil.copytree(draft_dir, final_draft_dir, dirs_exist_ok=True)

    return final_draft_dir


def test_eagle():
    patched_draft_path = patch_eagle_draft_with_lm_head(
        target_model_id="meta-llama/Llama-2-7b-hf",
        draft_model_id="yuhuili/EAGLE-llama2-chat-7B")
    llm = LLM(
        model="meta-llama/Llama-2-7b-hf",
        speculative_config={
            "model": patched_draft_path,
            "num_speculative_tokens": 5,
            "max_model_len": 128
        },
        max_num_seqs=1,
        max_model_len=128,
        tensor_parallel_size=2,
        override_neuron_config={
            "enable_eagle_speculation": True,
            "enable_fused_speculation": True,
            "fused_qkv": True
        },
    )
    prompts = [
        "The president of the United States is",
    ]
    # top_k=1 makes decoding greedy, so the output is deterministic and can
    # be checked against a fixed expected string.
    outputs = llm.generate(prompts, SamplingParams(top_k=1))
    expected_output = " the head of state and head of government of " \
                      "the United States. The president direct"

    for output in outputs:
        generated_text = output.outputs[0].text
        print(f"Prompt: {output.prompt!r}, Generated text: {generated_text!r}")
        assert expected_output == generated_text

    print("Neuron Eagle speculation test passed.")
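

# Convenience entry point, added as a sketch: the original file is meant to
# run under pytest. Invoking it directly assumes a Neuron instance with at
# least two NeuronCores (for tensor_parallel_size=2) and Hugging Face Hub
# access to the gated meta-llama/Llama-2-7b-hf checkpoint.
if __name__ == "__main__":
    test_eagle()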