diff --git a/docs/source/serving/multimodal_inputs.md b/docs/source/serving/multimodal_inputs.md index bcaa4f9b96..bb2997f008 100644 --- a/docs/source/serving/multimodal_inputs.md +++ b/docs/source/serving/multimodal_inputs.md @@ -213,10 +213,13 @@ Our OpenAI-compatible server accepts multi-modal data via the [Chat Completions :::{important} A chat template is **required** to use Chat Completions API. +For HF format models, the default chat template is defined inside `chat_template.json` or `tokenizer_config.json`. -Although most models come with a chat template, for others you have to define one yourself. -The chat template can be inferred based on the documentation on the model's HuggingFace repo. -For example, DeepSeek-VL2 requires a chat template that can be found here: <gh-file:examples/template_deepseek_vl2.jinja> +If no default chat template is available, we will first look for a built-in fallback in <gh-file:vllm/transformers_utils/chat_templates/registry.py>. +If no fallback is available, an error is raised and you have to provide the chat template manually via the `--chat-template` argument. + +For certain models, we provide alternative chat templates inside <gh-dir:examples>. +For example, VLM2Vec uses <gh-file:examples/template_vlm2vec.jinja> which is different from the default one for Phi-3-Vision. ::: ### Image Inputs diff --git a/examples/template_florence2.jinja b/examples/template_florence2.jinja deleted file mode 100644 index 3fa2cccc24..0000000000 --- a/examples/template_florence2.jinja +++ /dev/null @@ -1,3 +0,0 @@ -{%- for message in messages -%} - {{- message['content'] -}} -{%- endfor -%} diff --git a/examples/template_paligemma.jinja b/examples/template_paligemma.jinja deleted file mode 100644 index 3fa2cccc24..0000000000 --- a/examples/template_paligemma.jinja +++ /dev/null @@ -1,3 +0,0 @@ -{%- for message in messages -%} - {{- message['content'] -}} -{%- endfor -%} diff --git a/examples/template_qwen_vl.jinja b/examples/template_qwen_vl.jinja deleted file mode 100644 index 3fa2cccc24..0000000000 --- a/examples/template_qwen_vl.jinja +++ /dev/null @@ -1,3 +0,0 @@ -{%- for message in messages -%} - {{- message['content'] -}} -{%- endfor -%} diff --git a/tests/entrypoints/openai/test_chat_template.py b/tests/entrypoints/openai/test_chat_template.py index 78e40eeecd..48ede50e98 100644 --- a/tests/entrypoints/openai/test_chat_template.py +++ b/tests/entrypoints/openai/test_chat_template.py @@ -2,11 +2,13 @@ import pytest +from vllm.config import ModelConfig from vllm.entrypoints.chat_utils import (apply_hf_chat_template, load_chat_template) from vllm.entrypoints.openai.protocol import ChatCompletionRequest from vllm.transformers_utils.tokenizer import get_tokenizer +from ...models.registry import HF_EXAMPLE_MODELS from ...utils import VLLM_PATH chatml_jinja_path = VLLM_PATH / "examples/template_chatml.jinja" @@ -91,8 +93,22 @@ def test_no_load_chat_template_literallike(): MODEL_TEMPLATE_GENERATON_OUTPUT) def test_get_gen_prompt(model, template, add_generation_prompt, continue_final_message, expected_output): + model_info = HF_EXAMPLE_MODELS.find_hf_info(model) + model_info.check_available_online(on_fail="skip") + + model_config = ModelConfig( + model, + tokenizer=model_info.tokenizer or model, + tokenizer_mode=model_info.tokenizer_mode, + trust_remote_code=model_info.trust_remote_code, + hf_overrides=model_info.hf_overrides, + ) + # Initialize the tokenizer - tokenizer = get_tokenizer(tokenizer_name=model) + tokenizer = get_tokenizer( + tokenizer_name=model_config.tokenizer, + trust_remote_code=model_config.trust_remote_code, + ) template_content = load_chat_template(chat_template=template) # Create a mock request object using 
keyword arguments @@ -106,8 +122,8 @@ def test_get_gen_prompt(model, template, add_generation_prompt, # Call the function and get the result result = apply_hf_chat_template( + model_config, tokenizer, - trust_remote_code=True, conversation=mock_request.messages, chat_template=mock_request.chat_template or template_content, tools=None, diff --git a/tests/entrypoints/test_chat_utils.py b/tests/entrypoints/test_chat_utils.py index 1de30f0ac0..bcb25ed990 100644 --- a/tests/entrypoints/test_chat_utils.py +++ b/tests/entrypoints/test_chat_utils.py @@ -4,8 +4,6 @@ import warnings from typing import Optional import pytest -from packaging.version import Version -from transformers import __version__ as TRANSFORMERS_VERSION from vllm.assets.image import ImageAsset from vllm.config import ModelConfig @@ -19,6 +17,7 @@ from vllm.multimodal import MultiModalDataDict from vllm.multimodal.utils import encode_image_base64 from vllm.transformers_utils.tokenizer_group import TokenizerGroup +from ..models.registry import HF_EXAMPLE_MODELS from ..utils import VLLM_PATH EXAMPLES_DIR = VLLM_PATH / "examples" @@ -772,6 +771,7 @@ def test_multimodal_image_parsing_matches_hf(model, image_url): enable_lora=False, max_num_seqs=5, max_input_length=None, + trust_remote_code=model_config.trust_remote_code, ) tokenizer = tokenizer_group.tokenizer @@ -793,8 +793,8 @@ def test_multimodal_image_parsing_matches_hf(model, image_url): ) vllm_result = apply_hf_chat_template( + model_config, tokenizer, - trust_remote_code=model_config.trust_remote_code, conversation=conversation, chat_template=None, tools=None, @@ -813,6 +813,16 @@ def test_multimodal_image_parsing_matches_hf(model, image_url): @pytest.mark.parametrize("use_tools", [True, False]) def test_resolve_hf_chat_template(sample_json_schema, model, use_tools): """checks that chat_template is a dict type for HF models.""" + model_info = HF_EXAMPLE_MODELS.find_hf_info(model) + model_info.check_available_online(on_fail="skip") + + model_config = ModelConfig( + model, + tokenizer=model_info.tokenizer or model, + tokenizer_mode=model_info.tokenizer_mode, + trust_remote_code=model_info.trust_remote_code, + hf_overrides=model_info.hf_overrides, + ) # Build the tokenizer group and grab the underlying tokenizer tokenizer_group = TokenizerGroup( @@ -820,6 +830,7 @@ def test_resolve_hf_chat_template(sample_json_schema, model, use_tools): enable_lora=False, max_num_seqs=5, max_input_length=None, + trust_remote_code=model_config.trust_remote_code, ) tokenizer = tokenizer_group.tokenizer @@ -834,10 +845,10 @@ def test_resolve_hf_chat_template(sample_json_schema, model, use_tools): # Test detecting the tokenizer's chat_template chat_template = resolve_hf_chat_template( + model_config, tokenizer, chat_template=None, tools=tools, - trust_remote_code=True, ) assert isinstance(chat_template, str) @@ -857,24 +868,32 @@ def test_resolve_hf_chat_template(sample_json_schema, model, use_tools): ) # yapf: enable def test_resolve_content_format_hf_defined(model, expected_format): - if model == QWEN25VL_MODEL_ID and Version(TRANSFORMERS_VERSION) < Version( - "4.49.0"): - pytest.skip("Qwen2.5-VL requires transformers>=4.49.0") + model_info = HF_EXAMPLE_MODELS.find_hf_info(model) + model_info.check_available_online(on_fail="skip") + + model_config = ModelConfig( + model, + tokenizer=model_info.tokenizer or model, + tokenizer_mode=model_info.tokenizer_mode, + trust_remote_code=model_info.trust_remote_code, + hf_overrides=model_info.hf_overrides, + ) tokenizer_group = TokenizerGroup( model, 
enable_lora=False, max_num_seqs=5, max_input_length=None, + trust_remote_code=model_config.trust_remote_code, ) tokenizer = tokenizer_group.tokenizer # Test detecting the tokenizer's chat_template chat_template = resolve_hf_chat_template( + model_config, tokenizer, chat_template=None, tools=None, - trust_remote_code=True, ) assert isinstance(chat_template, str) @@ -884,11 +903,70 @@ def test_resolve_content_format_hf_defined(model, expected_format): print(_try_extract_ast(chat_template)) resolved_format = resolve_chat_template_content_format( + model_config, + None, # Test detecting the tokenizer's chat_template + None, + "auto", + tokenizer, + ) + + assert resolved_format == expected_format + + +# yapf: disable +@pytest.mark.parametrize( + ("model", "expected_format"), + [("Salesforce/blip2-opt-2.7b", "string"), + ("facebook/chameleon-7b", "string"), + ("deepseek-ai/deepseek-vl2-tiny", "string"), + ("microsoft/Florence-2-base", "string"), + ("adept/fuyu-8b", "string"), + ("google/paligemma-3b-mix-224", "string"), + ("Qwen/Qwen-VL", "string"), + ("Qwen/Qwen-VL-Chat", "string")], +) +# yapf: enable +def test_resolve_content_format_fallbacks(model, expected_format): + model_info = HF_EXAMPLE_MODELS.find_hf_info(model) + model_info.check_available_online(on_fail="skip") + + model_config = ModelConfig( + model, + tokenizer=model_info.tokenizer or model, + tokenizer_mode=model_info.tokenizer_mode, + trust_remote_code=model_info.trust_remote_code, + hf_overrides=model_info.hf_overrides, + ) + + tokenizer_group = TokenizerGroup( + model_config.tokenizer, + enable_lora=False, + max_num_seqs=5, + max_input_length=None, + trust_remote_code=model_config.trust_remote_code, + ) + tokenizer = tokenizer_group.tokenizer + + # Test detecting the tokenizer's chat_template + chat_template = resolve_hf_chat_template( + model_config, + tokenizer, + chat_template=None, + tools=None, + ) + assert isinstance(chat_template, str) + + print("[TEXT]") + print(chat_template) + print("[AST]") + print(_try_extract_ast(chat_template)) + + resolved_format = resolve_chat_template_content_format( + model_config, None, # Test detecting the tokenizer's chat_template None, "auto", tokenizer, - trust_remote_code=True, ) assert resolved_format == expected_format @@ -899,22 +977,14 @@ def test_resolve_content_format_hf_defined(model, expected_format): ("template_path", "expected_format"), [("template_alpaca.jinja", "string"), ("template_baichuan.jinja", "string"), - ("template_blip2.jinja", "string"), - ("template_chameleon.jinja", "string"), ("template_chatglm.jinja", "string"), ("template_chatglm2.jinja", "string"), ("template_chatml.jinja", "string"), - ("template_deepseek_vl2.jinja", "string"), ("template_dse_qwen2_vl.jinja", "openai"), ("template_falcon_180b.jinja", "string"), ("template_falcon.jinja", "string"), - ("template_florence2.jinja", "string"), - ("template_fuyu.jinja", "string"), ("template_inkbot.jinja", "string"), - ("template_paligemma.jinja", "string"), ("template_teleflm.jinja", "string"), - ("template_qwen_vl.jinja", "string"), - ("template_qwen_vl_chat.jinja", "string"), ("template_vlm2vec.jinja", "openai"), ("tool_chat_template_granite_20b_fc.jinja", "string"), ("tool_chat_template_hermes.jinja", "string"), @@ -926,11 +996,18 @@ def test_resolve_content_format_hf_defined(model, expected_format): ) # yapf: enable def test_resolve_content_format_examples(template_path, expected_format): + model_config = ModelConfig( + PHI3V_MODEL_ID, # Dummy + tokenizer=PHI3V_MODEL_ID, # Dummy + trust_remote_code=True, + ) + 
tokenizer_group = TokenizerGroup( - PHI3V_MODEL_ID, + PHI3V_MODEL_ID, # Dummy enable_lora=False, max_num_seqs=5, max_input_length=None, + trust_remote_code=model_config.trust_remote_code, ) dummy_tokenizer = tokenizer_group.tokenizer dummy_tokenizer.chat_template = None @@ -944,11 +1021,11 @@ def test_resolve_content_format_examples(template_path, expected_format): print(_try_extract_ast(chat_template)) resolved_format = resolve_chat_template_content_format( + model_config, chat_template, None, "auto", dummy_tokenizer, - trust_remote_code=True, ) assert resolved_format == expected_format diff --git a/tests/models/registry.py b/tests/models/registry.py index cd5e1dab0a..a1f2edac02 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -182,7 +182,9 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { "JAISLMHeadModel": _HfExamplesInfo("inceptionai/jais-13b-chat"), "JambaForCausalLM": _HfExamplesInfo("ai21labs/AI21-Jamba-1.5-Mini", extras={"tiny": "ai21labs/Jamba-tiny-dev"}), # noqa: E501 - "LlamaForCausalLM": _HfExamplesInfo("meta-llama/Llama-3.2-1B-Instruct"), + "LlamaForCausalLM": _HfExamplesInfo("meta-llama/Llama-3.2-1B-Instruct", + extras={"guard": "meta-llama/Llama-Guard-3-1B", # noqa: E501 + "hermes": "NousResearch/Hermes-3-Llama-3.1-8B"}), # noqa: E501 "LLaMAForCausalLM": _HfExamplesInfo("decapoda-research/llama-7b-hf", is_available_online=False), "MambaForCausalLM": _HfExamplesInfo("state-spaces/mamba-130m-hf"), @@ -378,7 +380,7 @@ _MULTIMODAL_EXAMPLE_MODELS = { # Therefore, we borrow the BartTokenizer from the original Bart model "Florence2ForConditionalGeneration": _HfExamplesInfo("microsoft/Florence-2-base", # noqa: E501 tokenizer="Isotr0py/Florence-2-tokenizer", - trust_remote_code=True), # noqa: E501 + trust_remote_code=True,), # noqa: E501 "MllamaForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-3.2-11B-Vision-Instruct"), # noqa: E501 "WhisperForConditionalGeneration": _HfExamplesInfo("openai/whisper-large-v3"), # noqa: E501 } diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index 23dded7f22..38fe985721 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -38,6 +38,10 @@ from vllm.config import ModelConfig from vllm.logger import init_logger from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalDataDict from vllm.multimodal.utils import MediaConnector +# yapf: disable +from vllm.transformers_utils.chat_templates import ( + get_chat_template_fallback_path) +# yapf: enable from vllm.transformers_utils.processor import cached_get_processor from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer @@ -325,11 +329,10 @@ def resolve_mistral_chat_template( return None def resolve_hf_chat_template( + model_config: ModelConfig, tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], chat_template: Optional[str], tools: Optional[list[dict[str, Any]]], - *, - trust_remote_code: bool, ) -> Optional[str]: # 1st priority: The given chat template if chat_template is not None: @@ -342,7 +345,7 @@ def resolve_hf_chat_template( tokenizer.name_or_path, processor_cls=(PreTrainedTokenizer, PreTrainedTokenizerFast, ProcessorMixin), - trust_remote_code=trust_remote_code, + trust_remote_code=model_config.trust_remote_code, ) if isinstance(processor, ProcessorMixin) and \ processor.chat_template is not None: @@ -358,22 +361,34 @@ def resolve_hf_chat_template( logger.debug("Failed to load AutoTokenizer chat template for %s", tokenizer.name_or_path, exc_info=True) - return None + # 4th priority: 
Predefined fallbacks + path = get_chat_template_fallback_path( + model_type=model_config.hf_config.model_type, + tokenizer_name_or_path=model_config.tokenizer, + ) + if path is not None: + logger.info("Loading chat template fallback for %s as there isn't one " + "defined on HF Hub.", tokenizer.name_or_path) + chat_template = load_chat_template(path) + else: + logger.debug("There is no chat template fallback for %s", + tokenizer.name_or_path) + + return chat_template def _resolve_chat_template_content_format( + model_config: ModelConfig, chat_template: Optional[str], tools: Optional[list[dict[str, Any]]], given_format: ChatTemplateContentFormatOption, tokenizer: AnyTokenizer, - *, - trust_remote_code: bool, ) -> _ChatTemplateContentFormat: if isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)): hf_chat_template = resolve_hf_chat_template( + model_config, tokenizer, chat_template=chat_template, - trust_remote_code=trust_remote_code, tools=tools, ) else: @@ -413,19 +428,18 @@ def _log_chat_template_content_format( def resolve_chat_template_content_format( + model_config: ModelConfig, chat_template: Optional[str], tools: Optional[list[dict[str, Any]]], given_format: ChatTemplateContentFormatOption, tokenizer: AnyTokenizer, - *, - trust_remote_code: bool = False, ) -> _ChatTemplateContentFormat: detected_format = _resolve_chat_template_content_format( + model_config, chat_template, tools, given_format, tokenizer, - trust_remote_code=trust_remote_code, ) _log_chat_template_content_format( @@ -1177,20 +1191,20 @@ def parse_chat_messages_futures( def apply_hf_chat_template( + model_config: ModelConfig, tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], conversation: list[ConversationMessage], chat_template: Optional[str], tools: Optional[list[dict[str, Any]]], *, - trust_remote_code: bool = False, tokenize: bool = False, # Different from HF's default **kwargs: Any, ) -> str: hf_chat_template = resolve_hf_chat_template( + model_config, tokenizer, chat_template=chat_template, tools=tools, - trust_remote_code=trust_remote_code, ) if hf_chat_template is None: diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index a04ab885a7..72ad79bd2d 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -726,11 +726,11 @@ class LLM: tokenizer = self.get_tokenizer(lora_request) model_config = self.llm_engine.get_model_config() resolved_content_format = resolve_chat_template_content_format( + model_config, chat_template, tools, chat_template_content_format, tokenizer, - trust_remote_code=model_config.trust_remote_code, ) _chat_template_kwargs: dict[str, Any] = dict( @@ -762,8 +762,8 @@ class LLM: ) else: prompt_str = apply_hf_chat_template( + model_config, tokenizer, - trust_remote_code=model_config.trust_remote_code, conversation=conversation, **_chat_template_kwargs, ) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index af132481b1..e034eacb24 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -937,10 +937,11 @@ async def init_app_state( chat_template=resolved_chat_template) else: hf_chat_template = resolve_hf_chat_template( + vllm_config.model_config, tokenizer, chat_template=None, tools=None, - trust_remote_code=model_config.trust_remote_code) + ) if hf_chat_template != resolved_chat_template: logger.warning( diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 25069c28a0..bb11650815 100644 --- 
a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -394,11 +394,11 @@ class OpenAIServing: model_config = self.model_config resolved_content_format = resolve_chat_template_content_format( + model_config, chat_template, tool_dicts, chat_template_content_format, tokenizer, - trust_remote_code=model_config.trust_remote_code, ) conversation, mm_data_future = parse_chat_messages_futures( messages, @@ -425,8 +425,8 @@ class OpenAIServing: ) else: request_prompt = apply_hf_chat_template( + model_config, tokenizer, - trust_remote_code=model_config.trust_remote_code, conversation=conversation, **_chat_template_kwargs, ) diff --git a/vllm/transformers_utils/chat_templates/__init__.py b/vllm/transformers_utils/chat_templates/__init__.py new file mode 100644 index 0000000000..fe2bd3ca41 --- /dev/null +++ b/vllm/transformers_utils/chat_templates/__init__.py @@ -0,0 +1,4 @@ +# SPDX-License-Identifier: Apache-2.0 +from .registry import get_chat_template_fallback_path + +__all__ = ["get_chat_template_fallback_path"] diff --git a/vllm/transformers_utils/chat_templates/registry.py b/vllm/transformers_utils/chat_templates/registry.py new file mode 100644 index 0000000000..853fed5d44 --- /dev/null +++ b/vllm/transformers_utils/chat_templates/registry.py @@ -0,0 +1,59 @@ +# SPDX-License-Identifier: Apache-2.0 +from pathlib import Path +from typing import Callable, Optional, Union + +from vllm.logger import init_logger + +logger = init_logger(__file__) + +CHAT_TEMPLATES_DIR = Path(__file__).parent + +ChatTemplatePath = Union[Path, Callable[[str], Optional[Path]]] + + +def _get_qwen_chat_template_fallback( + tokenizer_name_or_path: str) -> Optional[Path]: + if tokenizer_name_or_path.endswith("-Chat"): + return CHAT_TEMPLATES_DIR / "template_chatml.jinja" + + return CHAT_TEMPLATES_DIR / "template_basic.jinja" + + +# yapf: disable +_MODEL_TYPE_TO_CHAT_TEMPLATE_FALLBACK: dict[str, ChatTemplatePath] = { + "blip-2": CHAT_TEMPLATES_DIR / "template_blip2.jinja", + "chameleon": CHAT_TEMPLATES_DIR / "template_basic.jinja", + "deepseek_vl_v2": CHAT_TEMPLATES_DIR / "template_deepseek_vl2.jinja", + "florence2": CHAT_TEMPLATES_DIR / "template_basic.jinja", + "fuyu": CHAT_TEMPLATES_DIR / "template_fuyu.jinja", + "paligemma": CHAT_TEMPLATES_DIR / "template_basic.jinja", + "qwen": _get_qwen_chat_template_fallback, +} +# yapf: enable + + +def register_chat_template_fallback_path( + model_type: str, + chat_template: ChatTemplatePath, +) -> None: + if model_type in _MODEL_TYPE_TO_CHAT_TEMPLATE_FALLBACK: + logger.warning( + "Model type %s already has a chat template registered. 
" + "It will be overwritten by the new chat template %s.", model_type, + chat_template) + + _MODEL_TYPE_TO_CHAT_TEMPLATE_FALLBACK[model_type] = chat_template + + +def get_chat_template_fallback_path( + model_type: str, + tokenizer_name_or_path: str, +) -> Optional[Path]: + chat_template = _MODEL_TYPE_TO_CHAT_TEMPLATE_FALLBACK.get(model_type) + if callable(chat_template): + chat_template = chat_template(tokenizer_name_or_path) + + if chat_template is None: + return None + + return chat_template diff --git a/examples/template_chameleon.jinja b/vllm/transformers_utils/chat_templates/template_basic.jinja similarity index 100% rename from examples/template_chameleon.jinja rename to vllm/transformers_utils/chat_templates/template_basic.jinja diff --git a/examples/template_blip2.jinja b/vllm/transformers_utils/chat_templates/template_blip2.jinja similarity index 100% rename from examples/template_blip2.jinja rename to vllm/transformers_utils/chat_templates/template_blip2.jinja diff --git a/examples/template_qwen_vl_chat.jinja b/vllm/transformers_utils/chat_templates/template_chatml.jinja similarity index 100% rename from examples/template_qwen_vl_chat.jinja rename to vllm/transformers_utils/chat_templates/template_chatml.jinja diff --git a/examples/template_deepseek_vl2.jinja b/vllm/transformers_utils/chat_templates/template_deepseek_vl2.jinja similarity index 100% rename from examples/template_deepseek_vl2.jinja rename to vllm/transformers_utils/chat_templates/template_deepseek_vl2.jinja diff --git a/examples/template_fuyu.jinja b/vllm/transformers_utils/chat_templates/template_fuyu.jinja similarity index 100% rename from examples/template_fuyu.jinja rename to vllm/transformers_utils/chat_templates/template_fuyu.jinja