[Frontend] Chat template fallbacks for multimodal models (#17805)

Signed-off-by: DarkLight1337 <tlleungac@connect.ust.hk>
Cyrus Leung 2025-05-08 14:05:54 +08:00 committed by GitHub
parent 843b222723
commit 96722aa81d
18 changed files with 219 additions and 52 deletions

View File

@ -213,10 +213,13 @@ Our OpenAI-compatible server accepts multi-modal data via the [Chat Completions
:::{important}
A chat template is **required** to use Chat Completions API.
Although most models come with a chat template, for others you have to define one yourself.
The chat template can be inferred based on the documentation on the model's HuggingFace repo.
For example, DeepSeek-VL2 requires a chat template that can be found here: <gh-file:examples/template_deepseek_vl2.jinja>
For HF format models, the default chat template is defined inside `chat_template.json` or `tokenizer_config.json`.
If no default chat template is available, we will first look for a built-in fallback in <gh-file:vllm/transformers_utils/chat_templates/registry.py>.
If no fallback is available, an error is raised and you have to provide the chat template manually via the `--chat-template` argument.
For certain models, we provide alternative chat templates inside <gh-dir:vllm/examples>.
For example, VLM2Vec uses <gh-file:examples/template_vlm2vec.jinja> which is different from the default one for Phi-3-Vision.
:::
### Image Inputs
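
To make the manual-override path concrete, here is a minimal sketch (the served model name, template path, and port are illustrative assumptions, not part of this change): start the OpenAI-compatible server with an explicit `--chat-template`, then call it with the official `openai` client.

```python
# Minimal sketch, assuming a server started with a manual template override, e.g.:
#   vllm serve <your-model> --chat-template ./my_chat_template.jinja
# The model name, template path, and port below are placeholders.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

completion = client.chat.completions.create(
    model="<your-model>",
    messages=[{"role": "user", "content": "Describe this image."}],
)
print(completion.choices[0].message.content)
```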

View File

@ -1,3 +0,0 @@
{%- for message in messages -%}
{{- message['content'] -}}
{%- endfor -%}

View File

@ -1,3 +0,0 @@
{%- for message in messages -%}
{{- message['content'] -}}
{%- endfor -%}

View File

@ -1,3 +0,0 @@
{%- for message in messages -%}
{{- message['content'] -}}
{%- endfor -%}

View File

@ -2,11 +2,13 @@
import pytest
from vllm.config import ModelConfig
from vllm.entrypoints.chat_utils import (apply_hf_chat_template,
load_chat_template)
from vllm.entrypoints.openai.protocol import ChatCompletionRequest
from vllm.transformers_utils.tokenizer import get_tokenizer
from ...models.registry import HF_EXAMPLE_MODELS
from ...utils import VLLM_PATH
chatml_jinja_path = VLLM_PATH / "examples/template_chatml.jinja"
@ -91,8 +93,22 @@ def test_no_load_chat_template_literallike():
MODEL_TEMPLATE_GENERATON_OUTPUT)
def test_get_gen_prompt(model, template, add_generation_prompt,
continue_final_message, expected_output):
model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
model_info.check_available_online(on_fail="skip")
model_config = ModelConfig(
model,
tokenizer=model_info.tokenizer or model,
tokenizer_mode=model_info.tokenizer_mode,
trust_remote_code=model_info.trust_remote_code,
hf_overrides=model_info.hf_overrides,
)
# Initialize the tokenizer
tokenizer = get_tokenizer(tokenizer_name=model)
tokenizer = get_tokenizer(
tokenizer_name=model_config.tokenizer,
trust_remote_code=model_config.trust_remote_code,
)
template_content = load_chat_template(chat_template=template)
# Create a mock request object using keyword arguments
@ -106,8 +122,8 @@ def test_get_gen_prompt(model, template, add_generation_prompt,
# Call the function and get the result
result = apply_hf_chat_template(
model_config,
tokenizer,
trust_remote_code=True,
conversation=mock_request.messages,
chat_template=mock_request.chat_template or template_content,
tools=None,

View File

@ -4,8 +4,6 @@ import warnings
from typing import Optional
import pytest
from packaging.version import Version
from transformers import __version__ as TRANSFORMERS_VERSION
from vllm.assets.image import ImageAsset
from vllm.config import ModelConfig
@ -19,6 +17,7 @@ from vllm.multimodal import MultiModalDataDict
from vllm.multimodal.utils import encode_image_base64
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
from ..models.registry import HF_EXAMPLE_MODELS
from ..utils import VLLM_PATH
EXAMPLES_DIR = VLLM_PATH / "examples"
@ -772,6 +771,7 @@ def test_multimodal_image_parsing_matches_hf(model, image_url):
enable_lora=False,
max_num_seqs=5,
max_input_length=None,
trust_remote_code=model_config.trust_remote_code,
)
tokenizer = tokenizer_group.tokenizer
@ -793,8 +793,8 @@ def test_multimodal_image_parsing_matches_hf(model, image_url):
)
vllm_result = apply_hf_chat_template(
model_config,
tokenizer,
trust_remote_code=model_config.trust_remote_code,
conversation=conversation,
chat_template=None,
tools=None,
@ -813,6 +813,16 @@ def test_multimodal_image_parsing_matches_hf(model, image_url):
@pytest.mark.parametrize("use_tools", [True, False])
def test_resolve_hf_chat_template(sample_json_schema, model, use_tools):
"""checks that chat_template is a dict type for HF models."""
model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
model_info.check_available_online(on_fail="skip")
model_config = ModelConfig(
model,
tokenizer=model_info.tokenizer or model,
tokenizer_mode=model_info.tokenizer_mode,
trust_remote_code=model_info.trust_remote_code,
hf_overrides=model_info.hf_overrides,
)
# Build the tokenizer group and grab the underlying tokenizer
tokenizer_group = TokenizerGroup(
@ -820,6 +830,7 @@ def test_resolve_hf_chat_template(sample_json_schema, model, use_tools):
enable_lora=False,
max_num_seqs=5,
max_input_length=None,
trust_remote_code=model_config.trust_remote_code,
)
tokenizer = tokenizer_group.tokenizer
@ -834,10 +845,10 @@ def test_resolve_hf_chat_template(sample_json_schema, model, use_tools):
# Test detecting the tokenizer's chat_template
chat_template = resolve_hf_chat_template(
model_config,
tokenizer,
chat_template=None,
tools=tools,
trust_remote_code=True,
)
assert isinstance(chat_template, str)
@ -857,24 +868,32 @@ def test_resolve_hf_chat_template(sample_json_schema, model, use_tools):
)
# yapf: enable
def test_resolve_content_format_hf_defined(model, expected_format):
if model == QWEN25VL_MODEL_ID and Version(TRANSFORMERS_VERSION) < Version(
"4.49.0"):
pytest.skip("Qwen2.5-VL requires transformers>=4.49.0")
model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
model_info.check_available_online(on_fail="skip")
model_config = ModelConfig(
model,
tokenizer=model_info.tokenizer or model,
tokenizer_mode=model_info.tokenizer_mode,
trust_remote_code=model_info.trust_remote_code,
hf_overrides=model_info.hf_overrides,
)
tokenizer_group = TokenizerGroup(
model,
enable_lora=False,
max_num_seqs=5,
max_input_length=None,
trust_remote_code=model_config.trust_remote_code,
)
tokenizer = tokenizer_group.tokenizer
# Test detecting the tokenizer's chat_template
chat_template = resolve_hf_chat_template(
model_config,
tokenizer,
chat_template=None,
tools=None,
trust_remote_code=True,
)
assert isinstance(chat_template, str)
@ -884,11 +903,70 @@ def test_resolve_content_format_hf_defined(model, expected_format):
print(_try_extract_ast(chat_template))
resolved_format = resolve_chat_template_content_format(
model_config,
None, # Test detecting the tokenizer's chat_template
None,
"auto",
tokenizer,
)
assert resolved_format == expected_format
# yapf: disable
@pytest.mark.parametrize(
("model", "expected_format"),
[("Salesforce/blip2-opt-2.7b", "string"),
("facebook/chameleon-7b", "string"),
("deepseek-ai/deepseek-vl2-tiny", "string"),
("microsoft/Florence-2-base", "string"),
("adept/fuyu-8b", "string"),
("google/paligemma-3b-mix-224", "string"),
("Qwen/Qwen-VL", "string"),
("Qwen/Qwen-VL-Chat", "string")],
)
# yapf: enable
def test_resolve_content_format_fallbacks(model, expected_format):
model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
model_info.check_available_online(on_fail="skip")
model_config = ModelConfig(
model,
tokenizer=model_info.tokenizer or model,
tokenizer_mode=model_info.tokenizer_mode,
trust_remote_code=model_info.trust_remote_code,
hf_overrides=model_info.hf_overrides,
)
tokenizer_group = TokenizerGroup(
model_config.tokenizer,
enable_lora=False,
max_num_seqs=5,
max_input_length=None,
trust_remote_code=model_config.trust_remote_code,
)
tokenizer = tokenizer_group.tokenizer
# Test detecting the tokenizer's chat_template
chat_template = resolve_hf_chat_template(
model_config,
tokenizer,
chat_template=None,
tools=None,
)
assert isinstance(chat_template, str)
print("[TEXT]")
print(chat_template)
print("[AST]")
print(_try_extract_ast(chat_template))
resolved_format = resolve_chat_template_content_format(
model_config,
None, # Test detecting the tokenizer's chat_template
None,
"auto",
tokenizer,
trust_remote_code=True,
)
assert resolved_format == expected_format
@ -899,22 +977,14 @@ def test_resolve_content_format_hf_defined(model, expected_format):
("template_path", "expected_format"),
[("template_alpaca.jinja", "string"),
("template_baichuan.jinja", "string"),
("template_blip2.jinja", "string"),
("template_chameleon.jinja", "string"),
("template_chatglm.jinja", "string"), ("template_chatglm.jinja", "string"),
("template_chatglm2.jinja", "string"), ("template_chatglm2.jinja", "string"),
("template_chatml.jinja", "string"), ("template_chatml.jinja", "string"),
("template_deepseek_vl2.jinja", "string"),
("template_dse_qwen2_vl.jinja", "openai"), ("template_dse_qwen2_vl.jinja", "openai"),
("template_falcon_180b.jinja", "string"), ("template_falcon_180b.jinja", "string"),
("template_falcon.jinja", "string"), ("template_falcon.jinja", "string"),
("template_florence2.jinja", "string"),
("template_fuyu.jinja", "string"),
("template_inkbot.jinja", "string"), ("template_inkbot.jinja", "string"),
("template_paligemma.jinja", "string"),
("template_teleflm.jinja", "string"), ("template_teleflm.jinja", "string"),
("template_qwen_vl.jinja", "string"),
("template_qwen_vl_chat.jinja", "string"),
("template_vlm2vec.jinja", "openai"), ("template_vlm2vec.jinja", "openai"),
("tool_chat_template_granite_20b_fc.jinja", "string"), ("tool_chat_template_granite_20b_fc.jinja", "string"),
("tool_chat_template_hermes.jinja", "string"), ("tool_chat_template_hermes.jinja", "string"),
@ -926,11 +996,18 @@ def test_resolve_content_format_hf_defined(model, expected_format):
)
# yapf: enable
def test_resolve_content_format_examples(template_path, expected_format):
model_config = ModelConfig(
PHI3V_MODEL_ID, # Dummy
tokenizer=PHI3V_MODEL_ID, # Dummy
trust_remote_code=True,
)
tokenizer_group = TokenizerGroup(
PHI3V_MODEL_ID,
PHI3V_MODEL_ID, # Dummy
enable_lora=False,
max_num_seqs=5,
max_input_length=None,
trust_remote_code=model_config.trust_remote_code,
)
dummy_tokenizer = tokenizer_group.tokenizer
dummy_tokenizer.chat_template = None
@ -944,11 +1021,11 @@ def test_resolve_content_format_examples(template_path, expected_format):
print(_try_extract_ast(chat_template))
resolved_format = resolve_chat_template_content_format(
model_config,
chat_template,
None,
"auto",
dummy_tokenizer,
trust_remote_code=True,
)
assert resolved_format == expected_format

View File

@ -182,7 +182,9 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
"JAISLMHeadModel": _HfExamplesInfo("inceptionai/jais-13b-chat"), "JAISLMHeadModel": _HfExamplesInfo("inceptionai/jais-13b-chat"),
"JambaForCausalLM": _HfExamplesInfo("ai21labs/AI21-Jamba-1.5-Mini", "JambaForCausalLM": _HfExamplesInfo("ai21labs/AI21-Jamba-1.5-Mini",
extras={"tiny": "ai21labs/Jamba-tiny-dev"}), # noqa: E501 extras={"tiny": "ai21labs/Jamba-tiny-dev"}), # noqa: E501
"LlamaForCausalLM": _HfExamplesInfo("meta-llama/Llama-3.2-1B-Instruct"), "LlamaForCausalLM": _HfExamplesInfo("meta-llama/Llama-3.2-1B-Instruct",
extras={"guard": "meta-llama/Llama-Guard-3-1B", # noqa: E501
"hermes": "NousResearch/Hermes-3-Llama-3.1-8B"}), # noqa: E501
"LLaMAForCausalLM": _HfExamplesInfo("decapoda-research/llama-7b-hf", "LLaMAForCausalLM": _HfExamplesInfo("decapoda-research/llama-7b-hf",
is_available_online=False), is_available_online=False),
"MambaForCausalLM": _HfExamplesInfo("state-spaces/mamba-130m-hf"), "MambaForCausalLM": _HfExamplesInfo("state-spaces/mamba-130m-hf"),
@ -378,7 +380,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
# Therefore, we borrow the BartTokenizer from the original Bart model
"Florence2ForConditionalGeneration": _HfExamplesInfo("microsoft/Florence-2-base", # noqa: E501
tokenizer="Isotr0py/Florence-2-tokenizer",
trust_remote_code=True), # noqa: E501
trust_remote_code=True,), # noqa: E501
"MllamaForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-3.2-11B-Vision-Instruct"), # noqa: E501
"WhisperForConditionalGeneration": _HfExamplesInfo("openai/whisper-large-v3"), # noqa: E501
}

View File

@ -38,6 +38,10 @@ from vllm.config import ModelConfig
from vllm.logger import init_logger
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict
from vllm.multimodal.utils import MediaConnector
# yapf: disable
from vllm.transformers_utils.chat_templates import (
get_chat_template_fallback_path)
# yapf: enable
from vllm.transformers_utils.processor import cached_get_processor
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
@ -325,11 +329,10 @@ def resolve_mistral_chat_template(
return None
def resolve_hf_chat_template(
model_config: ModelConfig,
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
chat_template: Optional[str],
tools: Optional[list[dict[str, Any]]],
*,
trust_remote_code: bool,
) -> Optional[str]:
# 1st priority: The given chat template
if chat_template is not None:
@ -342,7 +345,7 @@ def resolve_hf_chat_template(
tokenizer.name_or_path,
processor_cls=(PreTrainedTokenizer, PreTrainedTokenizerFast,
ProcessorMixin),
trust_remote_code=trust_remote_code,
trust_remote_code=model_config.trust_remote_code,
)
if isinstance(processor, ProcessorMixin) and \
processor.chat_template is not None:
@ -358,22 +361,34 @@ def resolve_hf_chat_template(
logger.debug("Failed to load AutoTokenizer chat template for %s",
tokenizer.name_or_path, exc_info=True)
return None
# 4th priority: Predefined fallbacks
path = get_chat_template_fallback_path(
model_type=model_config.hf_config.model_type,
tokenizer_name_or_path=model_config.tokenizer,
)
if path is not None:
logger.info("Loading chat template fallback for %s as there isn't one "
"defined on HF Hub.", tokenizer.name_or_path)
chat_template = load_chat_template(path)
else:
logger.debug("There is no chat template fallback for %s",
tokenizer.name_or_path)
return chat_template
def _resolve_chat_template_content_format(
model_config: ModelConfig,
chat_template: Optional[str],
tools: Optional[list[dict[str, Any]]],
given_format: ChatTemplateContentFormatOption,
tokenizer: AnyTokenizer,
*,
trust_remote_code: bool,
) -> _ChatTemplateContentFormat:
if isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)):
hf_chat_template = resolve_hf_chat_template(
model_config,
tokenizer,
chat_template=chat_template,
trust_remote_code=trust_remote_code,
tools=tools,
)
else:
@ -413,19 +428,18 @@ def _log_chat_template_content_format(
def resolve_chat_template_content_format(
model_config: ModelConfig,
chat_template: Optional[str],
tools: Optional[list[dict[str, Any]]],
given_format: ChatTemplateContentFormatOption,
tokenizer: AnyTokenizer,
*,
trust_remote_code: bool = False,
) -> _ChatTemplateContentFormat:
detected_format = _resolve_chat_template_content_format(
model_config,
chat_template,
tools,
given_format,
tokenizer,
trust_remote_code=trust_remote_code,
)
_log_chat_template_content_format(
@ -1177,20 +1191,20 @@ def parse_chat_messages_futures(
def apply_hf_chat_template(
model_config: ModelConfig,
tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
conversation: list[ConversationMessage],
chat_template: Optional[str],
tools: Optional[list[dict[str, Any]]],
*,
trust_remote_code: bool = False,
tokenize: bool = False, # Different from HF's default
**kwargs: Any,
) -> str:
hf_chat_template = resolve_hf_chat_template(
model_config,
tokenizer,
chat_template=chat_template,
tools=tools,
trust_remote_code=trust_remote_code,
)
if hf_chat_template is None:
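
To summarize the resolution order that `resolve_hf_chat_template` now follows, here is a simplified, self-contained sketch; the function and argument names below are stand-ins for illustration only, not vLLM's actual implementation.

```python
from pathlib import Path
from typing import Optional


def resolve_chat_template_sketch(
    explicit_template: Optional[str],   # e.g. supplied via --chat-template
    processor_template: Optional[str],  # from chat_template.json via the HF processor
    tokenizer_template: Optional[str],  # from tokenizer_config.json
    fallback_path: Optional[Path],      # from get_chat_template_fallback_path(...)
) -> Optional[str]:
    # 1st priority: the template explicitly given by the caller.
    if explicit_template is not None:
        return explicit_template
    # 2nd priority: a template defined by the HF processor.
    if processor_template is not None:
        return processor_template
    # 3rd priority: the AutoTokenizer's own chat template.
    if tokenizer_template is not None:
        return tokenizer_template
    # 4th priority (new here): a predefined fallback keyed on model_type.
    if fallback_path is not None:
        return fallback_path.read_text()
    # Nothing found: downstream code raises and asks for --chat-template.
    return None
```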

View File

@ -726,11 +726,11 @@ class LLM:
tokenizer = self.get_tokenizer(lora_request)
model_config = self.llm_engine.get_model_config()
resolved_content_format = resolve_chat_template_content_format(
model_config,
chat_template,
tools,
chat_template_content_format,
tokenizer,
trust_remote_code=model_config.trust_remote_code,
)
_chat_template_kwargs: dict[str, Any] = dict(
@ -762,8 +762,8 @@ class LLM:
)
else:
prompt_str = apply_hf_chat_template(
model_config,
tokenizer,
trust_remote_code=model_config.trust_remote_code,
conversation=conversation,
**_chat_template_kwargs,
)
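
On the offline side, `LLM.chat` now routes through the same `model_config`-aware resolution; a hedged sketch (the model below is only an example of one covered by the new fallbacks and assumes it fits on your hardware):

```python
from vllm import LLM

# A multimodal model whose tokenizer ships no chat template; with this change,
# the registered fallback should be picked up automatically.
llm = LLM(model="adept/fuyu-8b")

outputs = llm.chat([{"role": "user", "content": "Hello!"}])
print(outputs[0].outputs[0].text)
```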

View File

@ -937,10 +937,11 @@ async def init_app_state(
chat_template=resolved_chat_template)
else:
hf_chat_template = resolve_hf_chat_template(
vllm_config.model_config,
tokenizer,
chat_template=None,
tools=None,
trust_remote_code=model_config.trust_remote_code)
)
if hf_chat_template != resolved_chat_template:
logger.warning(

View File

@ -394,11 +394,11 @@ class OpenAIServing:
model_config = self.model_config
resolved_content_format = resolve_chat_template_content_format(
model_config,
chat_template,
tool_dicts,
chat_template_content_format,
tokenizer,
trust_remote_code=model_config.trust_remote_code,
)
conversation, mm_data_future = parse_chat_messages_futures(
messages,
@ -425,8 +425,8 @@ class OpenAIServing:
)
else:
request_prompt = apply_hf_chat_template(
model_config,
tokenizer,
trust_remote_code=model_config.trust_remote_code,
conversation=conversation,
**_chat_template_kwargs,
)

View File

@ -0,0 +1,4 @@
# SPDX-License-Identifier: Apache-2.0
from .registry import get_chat_template_fallback_path
__all__ = ["get_chat_template_fallback_path"]

View File

@ -0,0 +1,59 @@
# SPDX-License-Identifier: Apache-2.0
from pathlib import Path
from typing import Callable, Optional, Union
from vllm.logger import init_logger
logger = init_logger(__file__)
CHAT_TEMPLATES_DIR = Path(__file__).parent
ChatTemplatePath = Union[Path, Callable[[str], Optional[Path]]]
def _get_qwen_chat_template_fallback(
tokenizer_name_or_path: str) -> Optional[Path]:
if tokenizer_name_or_path.endswith("-Chat"):
return CHAT_TEMPLATES_DIR / "template_chatml.jinja"
return CHAT_TEMPLATES_DIR / "template_basic.jinja"
# yapf: disable
_MODEL_TYPE_TO_CHAT_TEMPLATE_FALLBACK: dict[str, ChatTemplatePath] = {
"blip-2": CHAT_TEMPLATES_DIR / "template_blip2.jinja",
"chameleon": CHAT_TEMPLATES_DIR / "template_basic.jinja",
"deepseek_vl_v2": CHAT_TEMPLATES_DIR / "template_deepseek_vl2.jinja",
"florence2": CHAT_TEMPLATES_DIR / "template_basic.jinja",
"fuyu": CHAT_TEMPLATES_DIR / "template_fuyu.jinja",
"paligemma": CHAT_TEMPLATES_DIR / "template_basic.jinja",
"qwen": _get_qwen_chat_template_fallback,
}
# yapf: enable
def register_chat_template_fallback_path(
model_type: str,
chat_template: ChatTemplatePath,
) -> None:
if model_type in _MODEL_TYPE_TO_CHAT_TEMPLATE_FALLBACK:
logger.warning(
"Model type %s already has a chat template registered. "
"It will be overwritten by the new chat template %s.", model_type,
chat_template)
_MODEL_TYPE_TO_CHAT_TEMPLATE_FALLBACK[model_type] = chat_template
def get_chat_template_fallback_path(
model_type: str,
tokenizer_name_or_path: str,
) -> Optional[Path]:
chat_template = _MODEL_TYPE_TO_CHAT_TEMPLATE_FALLBACK.get(model_type)
if callable(chat_template):
chat_template = chat_template(tokenizer_name_or_path)
if chat_template is None:
return None
return chat_template
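
The new registry is keyed on `hf_config.model_type` and is what `resolve_hf_chat_template` consults in its fourth-priority step; a quick usage sketch of the exported helper (the printed paths depend on where vLLM is installed):

```python
from vllm.transformers_utils.chat_templates import get_chat_template_fallback_path

# Static entry: fuyu always maps to the bundled Fuyu template.
print(get_chat_template_fallback_path("fuyu", "adept/fuyu-8b"))

# Callable entry: Qwen "-Chat" checkpoints get ChatML, other Qwen checkpoints
# get the basic pass-through template.
print(get_chat_template_fallback_path("qwen", "Qwen/Qwen-VL-Chat"))
print(get_chat_template_fallback_path("qwen", "Qwen/Qwen-VL"))

# Unregistered model types return None, which triggers the
# "There is no chat template fallback" debug log upstream.
print(get_chat_template_fallback_path("llama", "meta-llama/Llama-3.2-1B-Instruct"))
```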