mirror of https://github.com/vllm-project/vllm.git

[Frontend] Chat template fallbacks for multimodal models (#17805)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>

commit 96722aa81d (parent 843b222723)
@@ -213,10 +213,13 @@ Our OpenAI-compatible server accepts multi-modal data via the [Chat Completions
 :::{important}
 A chat template is **required** to use Chat Completions API.

-Although most models come with a chat template, for others you have to define one yourself.
-The chat template can be inferred based on the documentation on the model's HuggingFace repo.
-For example, DeepSeek-VL2 requires a chat template that can be found here: <gh-file:examples/template_deepseek_vl2.jinja>
+For HF format models, the default chat template is defined inside `chat_template.json` or `tokenizer_config.json`.
+
+If no default chat template is available, we will first look for a built-in fallback in <gh-file:vllm/transformers_utils/chat_templates/registry.py>.
+If no fallback is available, an error is raised and you have to provide the chat template manually via the `--chat-template` argument.

 For certain models, we provide alternative chat templates inside <gh-dir:vllm/examples>.
 For example, VLM2Vec uses <gh-file:examples/template_vlm2vec.jinja> which is different from the default one for Phi-3-Vision.
 :::

 ### Image Inputs
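As an aside for readers of the updated docs above, here is a minimal offline sketch of supplying a chat template explicitly when none can be resolved automatically; the model choice and template path are illustrative and not part of this commit:

```python
# Minimal sketch (illustrative, not part of this commit): load a Jinja chat
# template from disk and pass it explicitly to LLM.chat().
from vllm import LLM
from vllm.entrypoints.chat_utils import load_chat_template

# Hypothetical local template file; point this at your own .jinja file.
template = load_chat_template(chat_template="./my_chat_template.jinja")

llm = LLM(model="deepseek-ai/deepseek-vl2-tiny", trust_remote_code=True)
outputs = llm.chat(
    [{"role": "user", "content": "Say hello."}],
    chat_template=template,
)
print(outputs[0].outputs[0].text)
```

The `--chat-template` argument mentioned in the docs is the equivalent for the online server.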
@@ -1,3 +0,0 @@
-{%- for message in messages -%}
-{{- message['content'] -}}
-{%- endfor -%}

@@ -1,3 +0,0 @@
-{%- for message in messages -%}
-{{- message['content'] -}}
-{%- endfor -%}

@@ -1,3 +0,0 @@
-{%- for message in messages -%}
-{{- message['content'] -}}
-{%- endfor -%}
@@ -2,11 +2,13 @@

 import pytest

+from vllm.config import ModelConfig
 from vllm.entrypoints.chat_utils import (apply_hf_chat_template,
                                          load_chat_template)
 from vllm.entrypoints.openai.protocol import ChatCompletionRequest
 from vllm.transformers_utils.tokenizer import get_tokenizer

+from ...models.registry import HF_EXAMPLE_MODELS
 from ...utils import VLLM_PATH

 chatml_jinja_path = VLLM_PATH / "examples/template_chatml.jinja"
@@ -91,8 +93,22 @@ def test_no_load_chat_template_literallike():
                          MODEL_TEMPLATE_GENERATON_OUTPUT)
 def test_get_gen_prompt(model, template, add_generation_prompt,
                         continue_final_message, expected_output):
+    model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
+    model_info.check_available_online(on_fail="skip")
+
+    model_config = ModelConfig(
+        model,
+        tokenizer=model_info.tokenizer or model,
+        tokenizer_mode=model_info.tokenizer_mode,
+        trust_remote_code=model_info.trust_remote_code,
+        hf_overrides=model_info.hf_overrides,
+    )
+
     # Initialize the tokenizer
-    tokenizer = get_tokenizer(tokenizer_name=model)
+    tokenizer = get_tokenizer(
+        tokenizer_name=model_config.tokenizer,
+        trust_remote_code=model_config.trust_remote_code,
+    )
     template_content = load_chat_template(chat_template=template)

     # Create a mock request object using keyword arguments
@@ -106,8 +122,8 @@ def test_get_gen_prompt(model, template, add_generation_prompt,

     # Call the function and get the result
     result = apply_hf_chat_template(
+        model_config,
         tokenizer,
-        trust_remote_code=True,
         conversation=mock_request.messages,
         chat_template=mock_request.chat_template or template_content,
         tools=None,
@@ -4,8 +4,6 @@ import warnings
 from typing import Optional

 import pytest
-from packaging.version import Version
-from transformers import __version__ as TRANSFORMERS_VERSION

 from vllm.assets.image import ImageAsset
 from vllm.config import ModelConfig
@@ -19,6 +17,7 @@ from vllm.multimodal import MultiModalDataDict
 from vllm.multimodal.utils import encode_image_base64
 from vllm.transformers_utils.tokenizer_group import TokenizerGroup

+from ..models.registry import HF_EXAMPLE_MODELS
 from ..utils import VLLM_PATH

 EXAMPLES_DIR = VLLM_PATH / "examples"
@@ -772,6 +771,7 @@ def test_multimodal_image_parsing_matches_hf(model, image_url):
         enable_lora=False,
         max_num_seqs=5,
         max_input_length=None,
+        trust_remote_code=model_config.trust_remote_code,
     )
     tokenizer = tokenizer_group.tokenizer

@@ -793,8 +793,8 @@ def test_multimodal_image_parsing_matches_hf(model, image_url):
     )

     vllm_result = apply_hf_chat_template(
+        model_config,
         tokenizer,
-        trust_remote_code=model_config.trust_remote_code,
         conversation=conversation,
         chat_template=None,
         tools=None,
@@ -813,6 +813,16 @@ def test_multimodal_image_parsing_matches_hf(model, image_url):
 @pytest.mark.parametrize("use_tools", [True, False])
 def test_resolve_hf_chat_template(sample_json_schema, model, use_tools):
     """checks that chat_template is a dict type for HF models."""
+    model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
+    model_info.check_available_online(on_fail="skip")
+
+    model_config = ModelConfig(
+        model,
+        tokenizer=model_info.tokenizer or model,
+        tokenizer_mode=model_info.tokenizer_mode,
+        trust_remote_code=model_info.trust_remote_code,
+        hf_overrides=model_info.hf_overrides,
+    )

     # Build the tokenizer group and grab the underlying tokenizer
     tokenizer_group = TokenizerGroup(
@@ -820,6 +830,7 @@ def test_resolve_hf_chat_template(sample_json_schema, model, use_tools):
         enable_lora=False,
         max_num_seqs=5,
         max_input_length=None,
+        trust_remote_code=model_config.trust_remote_code,
     )
     tokenizer = tokenizer_group.tokenizer

@@ -834,10 +845,10 @@ def test_resolve_hf_chat_template(sample_json_schema, model, use_tools):

     # Test detecting the tokenizer's chat_template
     chat_template = resolve_hf_chat_template(
+        model_config,
         tokenizer,
         chat_template=None,
         tools=tools,
-        trust_remote_code=True,
     )
     assert isinstance(chat_template, str)

@@ -857,24 +868,32 @@ def test_resolve_hf_chat_template(sample_json_schema, model, use_tools):
 )
 # yapf: enable
 def test_resolve_content_format_hf_defined(model, expected_format):
-    if model == QWEN25VL_MODEL_ID and Version(TRANSFORMERS_VERSION) < Version(
-            "4.49.0"):
-        pytest.skip("Qwen2.5-VL requires transformers>=4.49.0")
+    model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
+    model_info.check_available_online(on_fail="skip")
+
+    model_config = ModelConfig(
+        model,
+        tokenizer=model_info.tokenizer or model,
+        tokenizer_mode=model_info.tokenizer_mode,
+        trust_remote_code=model_info.trust_remote_code,
+        hf_overrides=model_info.hf_overrides,
+    )

     tokenizer_group = TokenizerGroup(
         model,
         enable_lora=False,
         max_num_seqs=5,
         max_input_length=None,
+        trust_remote_code=model_config.trust_remote_code,
     )
     tokenizer = tokenizer_group.tokenizer

     # Test detecting the tokenizer's chat_template
     chat_template = resolve_hf_chat_template(
+        model_config,
         tokenizer,
         chat_template=None,
         tools=None,
-        trust_remote_code=True,
     )
     assert isinstance(chat_template, str)

@@ -884,11 +903,70 @@ def test_resolve_content_format_hf_defined(model, expected_format):
     print(_try_extract_ast(chat_template))

     resolved_format = resolve_chat_template_content_format(
+        model_config,
         None,  # Test detecting the tokenizer's chat_template
         None,
         "auto",
         tokenizer,
-        trust_remote_code=True,
     )

     assert resolved_format == expected_format
+
+
+# yapf: disable
+@pytest.mark.parametrize(
+    ("model", "expected_format"),
+    [("Salesforce/blip2-opt-2.7b", "string"),
+     ("facebook/chameleon-7b", "string"),
+     ("deepseek-ai/deepseek-vl2-tiny", "string"),
+     ("microsoft/Florence-2-base", "string"),
+     ("adept/fuyu-8b", "string"),
+     ("google/paligemma-3b-mix-224", "string"),
+     ("Qwen/Qwen-VL", "string"),
+     ("Qwen/Qwen-VL-Chat", "string")],
+)
+# yapf: enable
+def test_resolve_content_format_fallbacks(model, expected_format):
+    model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
+    model_info.check_available_online(on_fail="skip")
+
+    model_config = ModelConfig(
+        model,
+        tokenizer=model_info.tokenizer or model,
+        tokenizer_mode=model_info.tokenizer_mode,
+        trust_remote_code=model_info.trust_remote_code,
+        hf_overrides=model_info.hf_overrides,
+    )
+
+    tokenizer_group = TokenizerGroup(
+        model_config.tokenizer,
+        enable_lora=False,
+        max_num_seqs=5,
+        max_input_length=None,
+        trust_remote_code=model_config.trust_remote_code,
+    )
+    tokenizer = tokenizer_group.tokenizer
+
+    # Test detecting the tokenizer's chat_template
+    chat_template = resolve_hf_chat_template(
+        model_config,
+        tokenizer,
+        chat_template=None,
+        tools=None,
+    )
+    assert isinstance(chat_template, str)
+
+    print("[TEXT]")
+    print(chat_template)
+    print("[AST]")
+    print(_try_extract_ast(chat_template))
+
+    resolved_format = resolve_chat_template_content_format(
+        model_config,
+        None,  # Test detecting the tokenizer's chat_template
+        None,
+        "auto",
+        tokenizer,
+    )
+
+    assert resolved_format == expected_format
@@ -899,22 +977,14 @@ def test_resolve_content_format_hf_defined(model, expected_format):
     ("template_path", "expected_format"),
     [("template_alpaca.jinja", "string"),
      ("template_baichuan.jinja", "string"),
-     ("template_blip2.jinja", "string"),
-     ("template_chameleon.jinja", "string"),
      ("template_chatglm.jinja", "string"),
      ("template_chatglm2.jinja", "string"),
      ("template_chatml.jinja", "string"),
-     ("template_deepseek_vl2.jinja", "string"),
      ("template_dse_qwen2_vl.jinja", "openai"),
      ("template_falcon_180b.jinja", "string"),
      ("template_falcon.jinja", "string"),
-     ("template_florence2.jinja", "string"),
-     ("template_fuyu.jinja", "string"),
      ("template_inkbot.jinja", "string"),
-     ("template_paligemma.jinja", "string"),
      ("template_teleflm.jinja", "string"),
-     ("template_qwen_vl.jinja", "string"),
-     ("template_qwen_vl_chat.jinja", "string"),
      ("template_vlm2vec.jinja", "openai"),
      ("tool_chat_template_granite_20b_fc.jinja", "string"),
      ("tool_chat_template_hermes.jinja", "string"),
@@ -926,11 +996,18 @@ def test_resolve_content_format_hf_defined(model, expected_format):
 )
 # yapf: enable
 def test_resolve_content_format_examples(template_path, expected_format):
+    model_config = ModelConfig(
+        PHI3V_MODEL_ID,  # Dummy
+        tokenizer=PHI3V_MODEL_ID,  # Dummy
+        trust_remote_code=True,
+    )
+
     tokenizer_group = TokenizerGroup(
-        PHI3V_MODEL_ID,
+        PHI3V_MODEL_ID,  # Dummy
         enable_lora=False,
         max_num_seqs=5,
         max_input_length=None,
+        trust_remote_code=model_config.trust_remote_code,
     )
     dummy_tokenizer = tokenizer_group.tokenizer
     dummy_tokenizer.chat_template = None
@@ -944,11 +1021,11 @@ def test_resolve_content_format_examples(template_path, expected_format):
     print(_try_extract_ast(chat_template))

     resolved_format = resolve_chat_template_content_format(
+        model_config,
         chat_template,
         None,
         "auto",
         dummy_tokenizer,
-        trust_remote_code=True,
     )

     assert resolved_format == expected_format
@@ -182,7 +182,9 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
     "JAISLMHeadModel": _HfExamplesInfo("inceptionai/jais-13b-chat"),
     "JambaForCausalLM": _HfExamplesInfo("ai21labs/AI21-Jamba-1.5-Mini",
                                         extras={"tiny": "ai21labs/Jamba-tiny-dev"}),  # noqa: E501
-    "LlamaForCausalLM": _HfExamplesInfo("meta-llama/Llama-3.2-1B-Instruct"),
+    "LlamaForCausalLM": _HfExamplesInfo("meta-llama/Llama-3.2-1B-Instruct",
+                                        extras={"guard": "meta-llama/Llama-Guard-3-1B",  # noqa: E501
+                                                "hermes": "NousResearch/Hermes-3-Llama-3.1-8B"}),  # noqa: E501
     "LLaMAForCausalLM": _HfExamplesInfo("decapoda-research/llama-7b-hf",
                                         is_available_online=False),
     "MambaForCausalLM": _HfExamplesInfo("state-spaces/mamba-130m-hf"),
@@ -378,7 +380,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
     # Therefore, we borrow the BartTokenizer from the original Bart model
     "Florence2ForConditionalGeneration": _HfExamplesInfo("microsoft/Florence-2-base",  # noqa: E501
                                                          tokenizer="Isotr0py/Florence-2-tokenizer",
-                                                         trust_remote_code=True),  # noqa: E501
+                                                         trust_remote_code=True,),  # noqa: E501
     "MllamaForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-3.2-11B-Vision-Instruct"),  # noqa: E501
     "WhisperForConditionalGeneration": _HfExamplesInfo("openai/whisper-large-v3"),  # noqa: E501
 }
@@ -38,6 +38,10 @@ from vllm.config import ModelConfig
 from vllm.logger import init_logger
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict
 from vllm.multimodal.utils import MediaConnector
+# yapf: disable
+from vllm.transformers_utils.chat_templates import (
+    get_chat_template_fallback_path)
+# yapf: enable
 from vllm.transformers_utils.processor import cached_get_processor
 from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer

@@ -325,11 +329,10 @@ def resolve_mistral_chat_template(
     return None

 def resolve_hf_chat_template(
+    model_config: ModelConfig,
     tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
     chat_template: Optional[str],
     tools: Optional[list[dict[str, Any]]],
-    *,
-    trust_remote_code: bool,
 ) -> Optional[str]:
     # 1st priority: The given chat template
     if chat_template is not None:
@@ -342,7 +345,7 @@ def resolve_hf_chat_template(
             tokenizer.name_or_path,
             processor_cls=(PreTrainedTokenizer, PreTrainedTokenizerFast,
                            ProcessorMixin),
-            trust_remote_code=trust_remote_code,
+            trust_remote_code=model_config.trust_remote_code,
         )
         if isinstance(processor, ProcessorMixin) and \
                 processor.chat_template is not None:
@@ -358,22 +361,34 @@ def resolve_hf_chat_template(
         logger.debug("Failed to load AutoTokenizer chat template for %s",
                      tokenizer.name_or_path, exc_info=True)

-    return None
+    # 4th priority: Predefined fallbacks
+    path = get_chat_template_fallback_path(
+        model_type=model_config.hf_config.model_type,
+        tokenizer_name_or_path=model_config.tokenizer,
+    )
+    if path is not None:
+        logger.info("Loading chat template fallback for %s as there isn't one "
+                    "defined on HF Hub.", tokenizer.name_or_path)
+        chat_template = load_chat_template(path)
+    else:
+        logger.debug("There is no chat template fallback for %s",
+                     tokenizer.name_or_path)
+
+    return chat_template


 def _resolve_chat_template_content_format(
+    model_config: ModelConfig,
     chat_template: Optional[str],
     tools: Optional[list[dict[str, Any]]],
     given_format: ChatTemplateContentFormatOption,
     tokenizer: AnyTokenizer,
-    *,
-    trust_remote_code: bool,
 ) -> _ChatTemplateContentFormat:
     if isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)):
         hf_chat_template = resolve_hf_chat_template(
+            model_config,
             tokenizer,
             chat_template=chat_template,
-            trust_remote_code=trust_remote_code,
             tools=tools,
         )
     else:
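The new fourth-priority fallback above also changes how callers invoke `resolve_hf_chat_template`: the `ModelConfig` is now passed positionally and `trust_remote_code` is read from it. A minimal sketch of the updated call pattern, mirroring the updated tests in this diff (the model below is illustrative):

```python
# Minimal sketch of the updated resolve_hf_chat_template() call pattern
# (mirrors the updated tests in this diff; the model is illustrative).
from vllm.config import ModelConfig
from vllm.entrypoints.chat_utils import resolve_hf_chat_template
from vllm.transformers_utils.tokenizer import get_tokenizer

model_config = ModelConfig("facebook/chameleon-7b",
                           tokenizer="facebook/chameleon-7b")
tokenizer = get_tokenizer(
    tokenizer_name=model_config.tokenizer,
    trust_remote_code=model_config.trust_remote_code,
)

# With chat_template=None, resolution falls through the priorities:
# given template -> processor template -> tokenizer template ->
# predefined fallback keyed by model_config.hf_config.model_type.
chat_template = resolve_hf_chat_template(
    model_config,
    tokenizer,
    chat_template=None,
    tools=None,
)
assert chat_template is None or isinstance(chat_template, str)
```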
@@ -413,19 +428,18 @@ def _log_chat_template_content_format(


 def resolve_chat_template_content_format(
+    model_config: ModelConfig,
     chat_template: Optional[str],
     tools: Optional[list[dict[str, Any]]],
     given_format: ChatTemplateContentFormatOption,
    tokenizer: AnyTokenizer,
-    *,
-    trust_remote_code: bool = False,
 ) -> _ChatTemplateContentFormat:
     detected_format = _resolve_chat_template_content_format(
+        model_config,
         chat_template,
         tools,
         given_format,
         tokenizer,
-        trust_remote_code=trust_remote_code,
     )

     _log_chat_template_content_format(
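The public wrapper follows the same pattern, so detection of the "string" vs "openai" content format now also sees the built-in fallbacks. A minimal sketch (positional order mirrors the updated tests; the model is illustrative):

```python
# Minimal sketch of the updated resolve_chat_template_content_format() call
# (positional order mirrors the updated tests; the model is illustrative).
from vllm.config import ModelConfig
from vllm.entrypoints.chat_utils import resolve_chat_template_content_format
from vllm.transformers_utils.tokenizer import get_tokenizer

model_config = ModelConfig("Salesforce/blip2-opt-2.7b",
                           tokenizer="Salesforce/blip2-opt-2.7b")
tokenizer = get_tokenizer(tokenizer_name=model_config.tokenizer)

# "auto" lets vLLM inspect the resolved template (including the new
# built-in fallbacks) and decide between "string" and "openai" content.
content_format = resolve_chat_template_content_format(
    model_config,
    None,    # chat_template: detect from the tokenizer / fallback registry
    None,    # tools
    "auto",  # given_format
    tokenizer,
)
print(content_format)  # the new tests expect "string" for BLIP-2
```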
@@ -1177,20 +1191,20 @@ def parse_chat_messages_futures(


 def apply_hf_chat_template(
+    model_config: ModelConfig,
     tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
     conversation: list[ConversationMessage],
     chat_template: Optional[str],
     tools: Optional[list[dict[str, Any]]],
     *,
-    trust_remote_code: bool = False,
     tokenize: bool = False,  # Different from HF's default
     **kwargs: Any,
 ) -> str:
     hf_chat_template = resolve_hf_chat_template(
+        model_config,
         tokenizer,
         chat_template=chat_template,
         tools=tools,
-        trust_remote_code=trust_remote_code,
     )

     if hf_chat_template is None:
@@ -726,11 +726,11 @@ class LLM:
         tokenizer = self.get_tokenizer(lora_request)
         model_config = self.llm_engine.get_model_config()
         resolved_content_format = resolve_chat_template_content_format(
+            model_config,
             chat_template,
             tools,
             chat_template_content_format,
             tokenizer,
-            trust_remote_code=model_config.trust_remote_code,
         )

         _chat_template_kwargs: dict[str, Any] = dict(
@@ -762,8 +762,8 @@ class LLM:
             )
         else:
             prompt_str = apply_hf_chat_template(
+                model_config,
                 tokenizer,
-                trust_remote_code=model_config.trust_remote_code,
                 conversation=conversation,
                 **_chat_template_kwargs,
             )
@@ -937,10 +937,11 @@ async def init_app_state(
                 chat_template=resolved_chat_template)
     else:
         hf_chat_template = resolve_hf_chat_template(
+            vllm_config.model_config,
             tokenizer,
             chat_template=None,
             tools=None,
-            trust_remote_code=model_config.trust_remote_code)
+        )

         if hf_chat_template != resolved_chat_template:
             logger.warning(
@@ -394,11 +394,11 @@ class OpenAIServing:
         model_config = self.model_config

         resolved_content_format = resolve_chat_template_content_format(
+            model_config,
             chat_template,
             tool_dicts,
             chat_template_content_format,
             tokenizer,
-            trust_remote_code=model_config.trust_remote_code,
         )
         conversation, mm_data_future = parse_chat_messages_futures(
             messages,
@@ -425,8 +425,8 @@ class OpenAIServing:
             )
         else:
             request_prompt = apply_hf_chat_template(
+                model_config,
                 tokenizer,
-                trust_remote_code=model_config.trust_remote_code,
                 conversation=conversation,
                 **_chat_template_kwargs,
             )
@@ -0,0 +1,4 @@
+# SPDX-License-Identifier: Apache-2.0
+from .registry import get_chat_template_fallback_path
+
+__all__ = ["get_chat_template_fallback_path"]
@@ -0,0 +1,59 @@
+# SPDX-License-Identifier: Apache-2.0
+from pathlib import Path
+from typing import Callable, Optional, Union
+
+from vllm.logger import init_logger
+
+logger = init_logger(__file__)
+
+CHAT_TEMPLATES_DIR = Path(__file__).parent
+
+ChatTemplatePath = Union[Path, Callable[[str], Optional[Path]]]
+
+
+def _get_qwen_chat_template_fallback(
+        tokenizer_name_or_path: str) -> Optional[Path]:
+    if tokenizer_name_or_path.endswith("-Chat"):
+        return CHAT_TEMPLATES_DIR / "template_chatml.jinja"
+
+    return CHAT_TEMPLATES_DIR / "template_basic.jinja"
+
+
+# yapf: disable
+_MODEL_TYPE_TO_CHAT_TEMPLATE_FALLBACK: dict[str, ChatTemplatePath] = {
+    "blip-2": CHAT_TEMPLATES_DIR / "template_blip2.jinja",
+    "chameleon": CHAT_TEMPLATES_DIR / "template_basic.jinja",
+    "deepseek_vl_v2": CHAT_TEMPLATES_DIR / "template_deepseek_vl2.jinja",
+    "florence2": CHAT_TEMPLATES_DIR / "template_basic.jinja",
+    "fuyu": CHAT_TEMPLATES_DIR / "template_fuyu.jinja",
+    "paligemma": CHAT_TEMPLATES_DIR / "template_basic.jinja",
+    "qwen": _get_qwen_chat_template_fallback,
+}
+# yapf: enable
+
+
+def register_chat_template_fallback_path(
+    model_type: str,
+    chat_template: ChatTemplatePath,
+) -> None:
+    if model_type in _MODEL_TYPE_TO_CHAT_TEMPLATE_FALLBACK:
+        logger.warning(
+            "Model type %s already has a chat template registered. "
+            "It will be overwritten by the new chat template %s.", model_type,
+            chat_template)
+
+    _MODEL_TYPE_TO_CHAT_TEMPLATE_FALLBACK[model_type] = chat_template
+
+
+def get_chat_template_fallback_path(
+    model_type: str,
+    tokenizer_name_or_path: str,
+) -> Optional[Path]:
+    chat_template = _MODEL_TYPE_TO_CHAT_TEMPLATE_FALLBACK.get(model_type)
+    if callable(chat_template):
+        chat_template = chat_template(tokenizer_name_or_path)
+
+    if chat_template is None:
+        return None
+
+    return chat_template
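Finally, a short sketch of how the new registry can be queried or extended from user code; the "my_model" entry is hypothetical, and only `get_chat_template_fallback_path` is re-exported from the package `__init__`:

```python
# Minimal sketch of querying/extending the new fallback registry
# (the "my_model" model_type and the template path are hypothetical).
from pathlib import Path

from vllm.transformers_utils.chat_templates import get_chat_template_fallback_path
from vllm.transformers_utils.chat_templates.registry import (
    register_chat_template_fallback_path)

# Fallbacks are keyed by HF model_type; for "qwen", only "-Chat" checkpoints
# get the ChatML template, other checkpoints get the basic template.
path = get_chat_template_fallback_path(
    model_type="qwen",
    tokenizer_name_or_path="Qwen/Qwen-VL-Chat",
)
print(path)  # .../chat_templates/template_chatml.jinja

# Out-of-tree models can register their own fallback before serving.
register_chat_template_fallback_path(
    model_type="my_model",
    chat_template=Path("/path/to/my_template.jinja"),
)
```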