[Frontend] Chat template fallbacks for multimodal models (#17805)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Cyrus Leung 2025-05-08 14:05:54 +08:00 committed by GitHub
parent 843b222723
commit 96722aa81d
18 changed files with 219 additions and 52 deletions

View File

@ -213,10 +213,13 @@ Our OpenAI-compatible server accepts multi-modal data via the [Chat Completions
:::{important}
A chat template is **required** to use the Chat Completions API.
For HF format models, the default chat template is defined inside `chat_template.json` or `tokenizer_config.json`.
Although most models come with a chat template, for others you have to define one yourself.
The chat template can usually be inferred from the documentation on the model's HuggingFace repo.
For example, DeepSeek-VL2 requires a chat template that can be found here: <gh-file:examples/template_deepseek_vl2.jinja>
If the model does not define a default chat template, we first look for a built-in fallback in <gh-file:vllm/transformers_utils/chat_templates/registry.py>.
If no fallback is available either, an error is raised and you have to provide the chat template manually via the `--chat-template` argument (see the sketch after this note).
For certain models, we provide alternative chat templates inside <gh-dir:examples>.
For example, VLM2Vec uses <gh-file:examples/template_vlm2vec.jinja>, which is different from the default one for Phi-3-Vision.
:::
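As a concrete illustration of the precedence above, here is a minimal sketch of overriding the chat template offline through `LLM.chat`; the model name and template path are placeholders rather than real artifacts.

```python
# Minimal sketch: manually supplying a chat template when the model defines
# none and no built-in fallback applies. Model name and template path are
# placeholders.
from vllm import LLM

llm = LLM(model="your-org/your-multimodal-model", trust_remote_code=True)

# Read the Jinja template and pass it explicitly; an explicit template takes
# priority over anything detected on HF Hub and over any built-in fallback.
with open("path/to/your_template.jinja") as f:
    chat_template = f.read()

outputs = llm.chat(
    [{"role": "user", "content": "Describe the weather today."}],
    chat_template=chat_template,
)
print(outputs[0].outputs[0].text)
```

For online serving, the same override is available through the `--chat-template` argument of `vllm serve`.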
### Image Inputs

View File

@ -1,3 +0,0 @@
{%- for message in messages -%}
{{- message['content'] -}}
{%- endfor -%}

View File

@ -1,3 +0,0 @@
{%- for message in messages -%}
{{- message['content'] -}}
{%- endfor -%}

View File

@ -1,3 +0,0 @@
{%- for message in messages -%}
{{- message['content'] -}}
{%- endfor -%}

View File

@ -2,11 +2,13 @@
import pytest
from vllm.config import ModelConfig
from vllm.entrypoints.chat_utils import (apply_hf_chat_template,
load_chat_template)
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
from vllm.transformers_utils.tokenizer import get_tokenizer
from ...models.registry import HF_EXAMPLE_MODELS
from ...utils import VLLM_PATH
chatml_jinja_path = VLLM_PATH / "examples/template_chatml.jinja"
@ -91,8 +93,22 @@ def test_no_load_chat_template_literallike():
MODEL_TEMPLATE_GENERATON_OUTPUT)
def test_get_gen_prompt(model, template, add_generation_prompt,
continue_final_message, expected_output):
model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
model_info.check_available_online(on_fail="skip")
model_config = ModelConfig(
model,
tokenizer=model_info.tokenizer or model,
tokenizer_mode=model_info.tokenizer_mode,
trust_remote_code=model_info.trust_remote_code,
hf_overrides=model_info.hf_overrides,
)
# Initialize the tokenizer
tokenizer = get_tokenizer(tokenizer_name=model)
tokenizer = get_tokenizer(
tokenizer_name=model_config.tokenizer,
trust_remote_code=model_config.trust_remote_code,
)
template_content = load_chat_template(chat_template=template)
# Create a mock request object using keyword arguments
@ -106,8 +122,8 @@ def test_get_gen_prompt(model, template, add_generation_prompt,
# Call the function and get the result
result = apply_hf_chat_template(
model_config,
tokenizer,
trust_remote_code=True,
conversation=mock_request.messages,
chat_template=mock_request.chat_template or template_content,
tools=None,

View File

@ -4,8 +4,6 @@ import warnings
from typing import Optional
import pytest
from packaging.version import Version
from transformers import __version__ as TRANSFORMERS_VERSION
from vllm.assets.image import ImageAsset
from vllm.config import ModelConfig
@ -19,6 +17,7 @@ from vllm.multimodal import MultiModalDataDict
from vllm.multimodal.utils import encode_image_base64
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
from ..models.registry import HF_EXAMPLE_MODELS
from ..utils import VLLM_PATH
EXAMPLES_DIR = VLLM_PATH / "examples"
@ -772,6 +771,7 @@ def test_multimodal_image_parsing_matches_hf(model, image_url):
enable_lora=False,
max_num_seqs=5,
max_input_length=None,
trust_remote_code=model_config.trust_remote_code,
)
tokenizer = tokenizer_group.tokenizer
@ -793,8 +793,8 @@ def test_multimodal_image_parsing_matches_hf(model, image_url):
)
vllm_result = apply_hf_chat_template(
model_config,
tokenizer,
trust_remote_code=model_config.trust_remote_code,
conversation=conversation,
chat_template=None,
tools=None,
@ -813,6 +813,16 @@ def test_multimodal_image_parsing_matches_hf(model, image_url):
@pytest.mark.parametrize("use_tools", [True, False])
def test_resolve_hf_chat_template(sample_json_schema, model, use_tools):
"""checks that chat_template is a dict type for HF models."""
model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
model_info.check_available_online(on_fail="skip")
model_config = ModelConfig(
model,
tokenizer=model_info.tokenizer or model,
tokenizer_mode=model_info.tokenizer_mode,
trust_remote_code=model_info.trust_remote_code,
hf_overrides=model_info.hf_overrides,
)
# Build the tokenizer group and grab the underlying tokenizer
tokenizer_group = TokenizerGroup(
@ -820,6 +830,7 @@ def test_resolve_hf_chat_template(sample_json_schema, model, use_tools):
enable_lora=False,
max_num_seqs=5,
max_input_length=None,
trust_remote_code=model_config.trust_remote_code,
)
tokenizer = tokenizer_group.tokenizer
@ -834,10 +845,10 @@ def test_resolve_hf_chat_template(sample_json_schema, model, use_tools):
# Test detecting the tokenizer's chat_template
chat_template = resolve_hf_chat_template(
model_config,
tokenizer,
chat_template=None,
tools=tools,
trust_remote_code=True,
)
assert isinstance(chat_template, str)
@ -857,24 +868,32 @@ def test_resolve_hf_chat_template(sample_json_schema, model, use_tools):
)
# yapf: enable
def test_resolve_content_format_hf_defined(model, expected_format):
if model == QWEN25VL_MODEL_ID and Version(TRANSFORMERS_VERSION) < Version(
"4.49.0"):
pytest.skip("Qwen2.5-VL requires transformers>=4.49.0")
model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
model_info.check_available_online(on_fail="skip")
model_config = ModelConfig(
model,
tokenizer=model_info.tokenizer or model,
tokenizer_mode=model_info.tokenizer_mode,
trust_remote_code=model_info.trust_remote_code,
hf_overrides=model_info.hf_overrides,
)
tokenizer_group = TokenizerGroup(
model,
enable_lora=False,
max_num_seqs=5,
max_input_length=None,
trust_remote_code=model_config.trust_remote_code,
)
tokenizer = tokenizer_group.tokenizer
# Test detecting the tokenizer's chat_template
chat_template = resolve_hf_chat_template(
model_config,
tokenizer,
chat_template=None,
tools=None,
trust_remote_code=True,
)
assert isinstance(chat_template, str)
@ -884,11 +903,70 @@ def test_resolve_content_format_hf_defined(model, expected_format):
print(_try_extract_ast(chat_template))
resolved_format = resolve_chat_template_content_format(
model_config,
None, # Test detecting the tokenizer's chat_template
None,
"auto",
tokenizer,
)
assert resolved_format == expected_format
# yapf: disable
@pytest.mark.parametrize(
("model", "expected_format"),
[("Salesforce/blip2-opt-2.7b", "string"),
("facebook/chameleon-7b", "string"),
("deepseek-ai/deepseek-vl2-tiny", "string"),
("microsoft/Florence-2-base", "string"),
("adept/fuyu-8b", "string"),
("google/paligemma-3b-mix-224", "string"),
("Qwen/Qwen-VL", "string"),
("Qwen/Qwen-VL-Chat", "string")],
)
# yapf: enable
def test_resolve_content_format_fallbacks(model, expected_format):
model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
model_info.check_available_online(on_fail="skip")
model_config = ModelConfig(
model,
tokenizer=model_info.tokenizer or model,
tokenizer_mode=model_info.tokenizer_mode,
trust_remote_code=model_info.trust_remote_code,
hf_overrides=model_info.hf_overrides,
)
tokenizer_group = TokenizerGroup(
model_config.tokenizer,
enable_lora=False,
max_num_seqs=5,
max_input_length=None,
trust_remote_code=model_config.trust_remote_code,
)
tokenizer = tokenizer_group.tokenizer
# Test detecting the tokenizer's chat_template
chat_template = resolve_hf_chat_template(
model_config,
tokenizer,
chat_template=None,
tools=None,
)
assert isinstance(chat_template, str)
print("[TEXT]")
print(chat_template)
print("[AST]")
print(_try_extract_ast(chat_template))
resolved_format = resolve_chat_template_content_format(
model_config,
None, # Test detecting the tokenizer's chat_template
None,
"auto",
tokenizer,
trust_remote_code=True,
)
assert resolved_format == expected_format
@ -899,22 +977,14 @@ def test_resolve_content_format_hf_defined(model, expected_format):
("template_path", "expected_format"),
[("template_alpaca.jinja", "string"),
("template_baichuan.jinja", "string"),
("template_blip2.jinja", "string"),
("template_chameleon.jinja", "string"),
("template_chatglm.jinja", "string"),
("template_chatglm2.jinja", "string"),
("template_chatml.jinja", "string"),
("template_deepseek_vl2.jinja", "string"),
("template_dse_qwen2_vl.jinja", "openai"),
("template_falcon_180b.jinja", "string"),
("template_falcon.jinja", "string"),
("template_florence2.jinja", "string"),
("template_fuyu.jinja", "string"),
("template_inkbot.jinja", "string"),
("template_paligemma.jinja", "string"),
("template_teleflm.jinja", "string"),
("template_qwen_vl.jinja", "string"),
("template_qwen_vl_chat.jinja", "string"),
("template_vlm2vec.jinja", "openai"),
("tool_chat_template_granite_20b_fc.jinja", "string"),
("tool_chat_template_hermes.jinja", "string"),
@ -926,11 +996,18 @@ def test_resolve_content_format_hf_defined(model, expected_format):
)
# yapf: enable
def test_resolve_content_format_examples(template_path, expected_format):
model_config = ModelConfig(
PHI3V_MODEL_ID, # Dummy
tokenizer=PHI3V_MODEL_ID, # Dummy
trust_remote_code=True,
)
tokenizer_group = TokenizerGroup(
PHI3V_MODEL_ID,
PHI3V_MODEL_ID, # Dummy
enable_lora=False,
max_num_seqs=5,
max_input_length=None,
trust_remote_code=model_config.trust_remote_code,
)
dummy_tokenizer = tokenizer_group.tokenizer
dummy_tokenizer.chat_template = None
@ -944,11 +1021,11 @@ def test_resolve_content_format_examples(template_path, expected_format):
print(_try_extract_ast(chat_template))
resolved_format = resolve_chat_template_content_format(
model_config,
chat_template,
None,
"auto",
dummy_tokenizer,
trust_remote_code=True,
)
assert resolved_format == expected_format

View File

@ -182,7 +182,9 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"JAISLMHeadModel": _HfExamplesInfo("inceptionai/jais-13b-chat"),
"JambaForCausalLM": _HfExamplesInfo("ai21labs/AI21-Jamba-1.5-Mini",
extras={"tiny": "ai21labs/Jamba-tiny-dev"}), # noqa: E501
"LlamaForCausalLM": _HfExamplesInfo("meta-llama/Llama-3.2-1B-Instruct"),
"LlamaForCausalLM": _HfExamplesInfo("meta-llama/Llama-3.2-1B-Instruct",
extras={"guard": "meta-llama/Llama-Guard-3-1B", # noqa: E501
"hermes": "NousResearch/Hermes-3-Llama-3.1-8B"}), # noqa: E501
"LLaMAForCausalLM": _HfExamplesInfo("decapoda-research/llama-7b-hf",
is_available_online=False),
"MambaForCausalLM": _HfExamplesInfo("state-spaces/mamba-130m-hf"),
@ -378,7 +380,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
# Therefore, we borrow the BartTokenizer from the original Bart model
"Florence2ForConditionalGeneration": _HfExamplesInfo("microsoft/Florence-2-base", # noqa: E501
tokenizer="Isotr0py/Florence-2-tokenizer",
trust_remote_code=True), # noqa: E501
trust_remote_code=True,), # noqa: E501
"MllamaForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-3.2-11B-Vision-Instruct"), # noqa: E501
"WhisperForConditionalGeneration": _HfExamplesInfo("openai/whisper-large-v3"), # noqa: E501
}

View File

@ -38,6 +38,10 @@ from vllm.config import ModelConfig
from vllm.logger import init_logger
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict
from vllm.multimodal.utils import MediaConnector
# yapf: disable
from vllm.transformers_utils.chat_templates import (
get_chat_template_fallback_path)
# yapf: enable
from vllm.transformers_utils.processor import cached_get_processor
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
@ -325,11 +329,10 @@ def resolve_mistral_chat_template(
return None
def resolve_hf_chat_template(
model_config: ModelConfig,
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
chat_template: Optional[str],
tools: Optional[list[dict[str, Any]]],
*,
trust_remote_code: bool,
) -> Optional[str]:
# 1st priority: The given chat template
if chat_template is not None:
@ -342,7 +345,7 @@ def resolve_hf_chat_template(
tokenizer.name_or_path,
processor_cls=(PreTrainedTokenizer, PreTrainedTokenizerFast,
ProcessorMixin),
trust_remote_code=trust_remote_code,
trust_remote_code=model_config.trust_remote_code,
)
if isinstance(processor, ProcessorMixin) and \
processor.chat_template is not None:
@ -358,22 +361,34 @@ def resolve_hf_chat_template(
logger.debug("Failed to load AutoTokenizer chat template for %s",
tokenizer.name_or_path, exc_info=True)
return None
# 4th priority: Predefined fallbacks
path = get_chat_template_fallback_path(
model_type=model_config.hf_config.model_type,
tokenizer_name_or_path=model_config.tokenizer,
)
if path is not None:
logger.info("Loading chat template fallback for %s as there isn't one "
"defined on HF Hub.", tokenizer.name_or_path)
chat_template = load_chat_template(path)
else:
logger.debug("There is no chat template fallback for %s",
tokenizer.name_or_path)
return chat_template
def _resolve_chat_template_content_format(
model_config: ModelConfig,
chat_template: Optional[str],
tools: Optional[list[dict[str, Any]]],
given_format: ChatTemplateContentFormatOption,
tokenizer: AnyTokenizer,
*,
trust_remote_code: bool,
) -> _ChatTemplateContentFormat:
if isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)):
hf_chat_template = resolve_hf_chat_template(
model_config,
tokenizer,
chat_template=chat_template,
trust_remote_code=trust_remote_code,
tools=tools,
)
else:
@ -413,19 +428,18 @@ def _log_chat_template_content_format(
def resolve_chat_template_content_format(
model_config: ModelConfig,
chat_template: Optional[str],
tools: Optional[list[dict[str, Any]]],
given_format: ChatTemplateContentFormatOption,
tokenizer: AnyTokenizer,
*,
trust_remote_code: bool = False,
) -> _ChatTemplateContentFormat:
detected_format = _resolve_chat_template_content_format(
model_config,
chat_template,
tools,
given_format,
tokenizer,
trust_remote_code=trust_remote_code,
)
_log_chat_template_content_format(
@ -1177,20 +1191,20 @@ def parse_chat_messages_futures(
def apply_hf_chat_template(
model_config: ModelConfig,
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
conversation: list[ConversationMessage],
chat_template: Optional[str],
tools: Optional[list[dict[str, Any]]],
*,
trust_remote_code: bool = False,
tokenize: bool = False, # Different from HF's default
**kwargs: Any,
) -> str:
hf_chat_template = resolve_hf_chat_template(
model_config,
tokenizer,
chat_template=chat_template,
tools=tools,
trust_remote_code=trust_remote_code,
)
if hf_chat_template is None:
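
For reference, a minimal sketch of the calling convention after this refactor: the helpers take `ModelConfig` as their first argument and read `trust_remote_code` from it rather than from a separate keyword. The model chosen below is illustrative and mirrors the updated tests in this PR.

```python
# Sketch of the new call pattern (model choice is illustrative).
from vllm.config import ModelConfig
from vllm.entrypoints.chat_utils import (apply_hf_chat_template,
                                         resolve_hf_chat_template)
from vllm.transformers_utils.tokenizer import get_tokenizer

model_config = ModelConfig("Qwen/Qwen-VL-Chat",
                           tokenizer="Qwen/Qwen-VL-Chat",
                           trust_remote_code=True)
tokenizer = get_tokenizer(
    tokenizer_name=model_config.tokenizer,
    trust_remote_code=model_config.trust_remote_code,
)

# With no template defined on HF Hub, this resolves to the built-in
# fallback registered for model_type="qwen" (ChatML for *-Chat checkpoints).
chat_template = resolve_hf_chat_template(
    model_config,
    tokenizer,
    chat_template=None,
    tools=None,
)

prompt = apply_hf_chat_template(
    model_config,
    tokenizer,
    conversation=[{"role": "user", "content": "Hello!"}],
    chat_template=None,
    tools=None,
)
```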

View File

@ -726,11 +726,11 @@ class LLM:
tokenizer = self.get_tokenizer(lora_request)
model_config = self.llm_engine.get_model_config()
resolved_content_format = resolve_chat_template_content_format(
model_config,
chat_template,
tools,
chat_template_content_format,
tokenizer,
trust_remote_code=model_config.trust_remote_code,
)
_chat_template_kwargs: dict[str, Any] = dict(
@ -762,8 +762,8 @@ class LLM:
)
else:
prompt_str = apply_hf_chat_template(
model_config,
tokenizer,
trust_remote_code=model_config.trust_remote_code,
conversation=conversation,
**_chat_template_kwargs,
)

View File

@ -937,10 +937,11 @@ async def init_app_state(
chat_template=resolved_chat_template)
else:
hf_chat_template = resolve_hf_chat_template(
vllm_config.model_config,
tokenizer,
chat_template=None,
tools=None,
trust_remote_code=model_config.trust_remote_code)
)
if hf_chat_template != resolved_chat_template:
logger.warning(

View File

@ -394,11 +394,11 @@ class OpenAIServing:
model_config = self.model_config
resolved_content_format = resolve_chat_template_content_format(
model_config,
chat_template,
tool_dicts,
chat_template_content_format,
tokenizer,
trust_remote_code=model_config.trust_remote_code,
)
conversation, mm_data_future = parse_chat_messages_futures(
messages,
@ -425,8 +425,8 @@ class OpenAIServing:
)
else:
request_prompt = apply_hf_chat_template(
model_config,
tokenizer,
trust_remote_code=model_config.trust_remote_code,
conversation=conversation,
**_chat_template_kwargs,
)

View File

@ -0,0 +1,4 @@
# SPDX-License-Identifier: Apache-2.0
from .registry import get_chat_template_fallback_path
__all__ = ["get_chat_template_fallback_path"]

View File

@ -0,0 +1,59 @@
# SPDX-License-Identifier: Apache-2.0
from pathlib import Path
from typing import Callable, Optional, Union
from vllm.logger import init_logger
logger = init_logger(__file__)
CHAT_TEMPLATES_DIR = Path(__file__).parent
ChatTemplatePath = Union[Path, Callable[[str], Optional[Path]]]
def _get_qwen_chat_template_fallback(
tokenizer_name_or_path: str) -> Optional[Path]:
if tokenizer_name_or_path.endswith("-Chat"):
return CHAT_TEMPLATES_DIR / "template_chatml.jinja"
return CHAT_TEMPLATES_DIR / "template_basic.jinja"
# yapf: disable
_MODEL_TYPE_TO_CHAT_TEMPLATE_FALLBACK: dict[str, ChatTemplatePath] = {
"blip-2": CHAT_TEMPLATES_DIR / "template_blip2.jinja",
"chameleon": CHAT_TEMPLATES_DIR / "template_basic.jinja",
"deepseek_vl_v2": CHAT_TEMPLATES_DIR / "template_deepseek_vl2.jinja",
"florence2": CHAT_TEMPLATES_DIR / "template_basic.jinja",
"fuyu": CHAT_TEMPLATES_DIR / "template_fuyu.jinja",
"paligemma": CHAT_TEMPLATES_DIR / "template_basic.jinja",
"qwen": _get_qwen_chat_template_fallback,
}
# yapf: enable
def register_chat_template_fallback_path(
model_type: str,
chat_template: ChatTemplatePath,
) -> None:
if model_type in _MODEL_TYPE_TO_CHAT_TEMPLATE_FALLBACK:
logger.warning(
"Model type %s already has a chat template registered. "
"It will be overwritten by the new chat template %s.", model_type,
chat_template)
_MODEL_TYPE_TO_CHAT_TEMPLATE_FALLBACK[model_type] = chat_template
def get_chat_template_fallback_path(
model_type: str,
tokenizer_name_or_path: str,
) -> Optional[Path]:
chat_template = _MODEL_TYPE_TO_CHAT_TEMPLATE_FALLBACK.get(model_type)
if callable(chat_template):
chat_template = chat_template(tokenizer_name_or_path)
if chat_template is None:
return None
return chat_template
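
A short usage sketch for the new registry module follows; the `my_model` type and its template path in the last call are hypothetical.

```python
# Usage sketch for the fallback registry defined above. The "my_model"
# entry and its template path are hypothetical.
from pathlib import Path

from vllm.transformers_utils.chat_templates import (
    get_chat_template_fallback_path)
from vllm.transformers_utils.chat_templates.registry import (
    register_chat_template_fallback_path)

# Built-in entries: Qwen-VL-Chat maps to the ChatML template, plain Qwen-VL
# to the basic pass-through template.
print(get_chat_template_fallback_path(
    model_type="qwen", tokenizer_name_or_path="Qwen/Qwen-VL-Chat"))
print(get_chat_template_fallback_path(
    model_type="qwen", tokenizer_name_or_path="Qwen/Qwen-VL"))

# Model types without an entry return None, meaning "no fallback available".
print(get_chat_template_fallback_path(
    model_type="llava", tokenizer_name_or_path="llava-hf/llava-1.5-7b-hf"))

# Out-of-tree models can register their own fallback (hypothetical path).
register_chat_template_fallback_path(
    "my_model", Path("/path/to/my_model_template.jinja"))
```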