From b801bf30d724e63e751797a18a0a2d76b228ad6d Mon Sep 17 00:00:00 2001
From: ShriKode
Date: Sat, 28 Jun 2025 22:21:17 +0000
Subject: [PATCH] iterate

Signed-off-by: ShriKode
---
 vllm/model_executor/models/gemma3n_mm.py | 286 +++++------------------
 vllm/model_executor/models/qwen_vl.py    |   3 +-
 2 files changed, 66 insertions(+), 223 deletions(-)

diff --git a/vllm/model_executor/models/gemma3n_mm.py b/vllm/model_executor/models/gemma3n_mm.py
index 039e5a37f3..931a2dcbb3 100644
--- a/vllm/model_executor/models/gemma3n_mm.py
+++ b/vllm/model_executor/models/gemma3n_mm.py
@@ -1,59 +1,58 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-import math
 from collections.abc import Iterable, Mapping, Sequence
-from typing import Any, Literal, Optional, TypedDict
+from typing import Any, Optional, TypedDict, Union
 
 import torch
 from torch import nn
-from transformers import BatchFeature, Gemma3nConfig, Gemma3nProcessor
-from transformers.models.gemma3n.processing_gemma3n import Gemma3nProcessorKwargs
-from transformers import AutoModel
+from transformers import AutoModel, BatchFeature
+from transformers.models.gemma3n import (Gemma3nAudioConfig,
+                                         Gemma3nAudioFeatureExtractor,
+                                         Gemma3nConfig, Gemma3nProcessor,
+                                         Gemma3nTextConfig,
+                                         Gemma3nVisionConfig)
+from transformers.models.siglip import SiglipImageProcessorFast
 
-import vllm.envs as envs
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.linear import RowParallelLinear
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
-from vllm.model_executor.layers.linear import (ColumnParallelLinear,
-                                               MergedColumnParallelLinear,
-                                               QKVParallelLinear,
-                                               ReplicatedLinear,
-                                               RowParallelLinear)
 from vllm.model_executor.models.module_mapping import MultiModelKeys
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                     MultiModalKwargs)
-from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
-                                   MultiModalDataItems)
+from vllm.multimodal.parse import ImageProcessorItems, MultiModalDataItems
 # yapf: disable
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         BaseProcessingInfo, BoundPromptUpdate,
                                         PlaceholderFeaturesInfo,
                                         PromptReplacement, PromptTargetMatch,
-                                        PromptUpdate, PromptUpdateDetails,
-                                        find_mm_placeholders,
+                                        PromptUpdate, find_mm_placeholders,
                                         replace_token_matches)
 # yapf: enable
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
 
-from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
-                         SupportsMultiModal, SupportsPP)
-from .siglip import SiglipVisionModel
-from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
+from .interfaces import MultiModalEmbeddings, SupportsMultiModal
+from .utils import (AutoWeightsLoader, WeightsMapper,
                     init_vllm_registered_model, maybe_prefix,
                     merge_multimodal_embeddings)
 
 logger = init_logger(__name__)
 
+# This should be based on model config but we hardcode them for now.
+TOKENS_PER_IMAGE = 256
+TOKENS_PER_AUDIO = 188
+
 
 class Gemma3nImagePixelInputs(TypedDict):
     pixel_values: torch.Tensor
     """Shape: `(batch_size * num_images, num_channels, height, width)`"""
 
+
 class Gemma3nAudioInputs(TypedDict):
     input_features: torch.Tensor
     """Shape: `(batch_size * num_audio, seq_length, num_features)`"""
@@ -64,7 +63,7 @@ class Gemma3nAudioInputs(TypedDict):
 Gemma3nImageInputs = Gemma3nImagePixelInputs
 
 
-class Gemma3ProcessingInfo(BaseProcessingInfo):
+class Gemma3nProcessingInfo(BaseProcessingInfo):
 
     def get_hf_config(self):
         return self.ctx.get_hf_config(Gemma3nConfig)
@@ -73,171 +72,26 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
         return self.ctx.get_hf_processor(Gemma3nProcessor, **kwargs)
 
     def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
-        return {"image": None}
+        return {"image": None, "audio": None}
 
-    def _resolve_image_kwargs(
-        self,
-        processor: Gemma3Processor,
-        keys: set[str],
-    ) -> dict[str, Any]:
-        image_processor = processor.image_processor
-        kwargs = processor._merge_kwargs(
-            Gemma3ProcessorKwargs,
-            tokenizer_init_kwargs=processor.tokenizer.init_kwargs,
-        )
+    def get_max_tokens_per_item(
+            self, seq_len: int,
+            mm_counts: Mapping[str, int]) -> Optional[Mapping[str, int]]:
 
-        images_kwargs = kwargs["images_kwargs"]
-
-        def _resolve_kw(key: str):
-            val = getattr(image_processor, key)
-            if val is None:
-                val = images_kwargs[key]
-
-            return val
-
-        return {k: _resolve_kw(k) for k in keys}
-
-    def get_num_crops(
-        self,
-        *,
-        image_width: int,
-        image_height: int,
-        processor: Optional[Gemma3Processor],
-    ) -> int:
-        if processor is None:
-            processor = self.get_hf_processor()
-
-        images_kwargs = self._resolve_image_kwargs(
-            processor, {
-                "do_pan_and_scan", "pan_and_scan_min_crop_size",
-                "pan_and_scan_max_num_crops",
-                "pan_and_scan_min_ratio_to_activate"
-            })
-
-        do_pan_and_scan = images_kwargs["do_pan_and_scan"]
-        pan_and_scan_min_crop_size = images_kwargs[
-            "pan_and_scan_min_crop_size"]
-        pan_and_scan_max_num_crops = images_kwargs[
-            "pan_and_scan_max_num_crops"]
-        pan_and_scan_min_ratio_to_activate = images_kwargs[
-            "pan_and_scan_min_ratio_to_activate"]
-
-        if not do_pan_and_scan:
-            return 0
-
-        if envs.VLLM_USE_V1:
-            logger.warning_once(
-                "`do_pan_and_scan=True` has suboptimal results on V1 "
-                "because of the simplified attention pattern being used.")
-
-        # Based on Gemma3ImageProcessor.pan_and_scan
-        if image_width >= image_height:
-            if image_width / image_height < pan_and_scan_min_ratio_to_activate:
-                return 0
-
-            num_crops_w = min(
-                int(math.floor(image_width / pan_and_scan_min_crop_size)),
-                int(math.floor(image_width / image_height + 0.5)),
-            )
-
-            num_crops_w = max(2, num_crops_w)
-            num_crops_w = min(pan_and_scan_max_num_crops, num_crops_w)
-            num_crops_h = 1
-        else:
-            if image_height / image_width < pan_and_scan_min_ratio_to_activate:
-                return 0
-
-            num_crops_h = min(
-                int(math.floor(image_height / pan_and_scan_min_crop_size)),
-                int(math.floor(image_height / image_width + 0.5)),
-            )
-
-            num_crops_h = max(2, num_crops_h)
-            num_crops_h = min(pan_and_scan_max_num_crops, num_crops_h)
-            num_crops_w = 1
-
-        crop_size_w = int(math.ceil(image_width / num_crops_w))
-        crop_size_h = int(math.ceil(image_height / num_crops_h))
-
-        if min(crop_size_w, crop_size_h) < pan_and_scan_min_crop_size:
-            return 0
-
-        return num_crops_w * num_crops_h
-
-    def get_image_repl(
-        self,
-        *,
-        image_width: int,
-        image_height: int,
-        processor: Optional[Gemma3Processor],
-    ) -> PromptUpdateDetails[str]:
-        if processor is None:
-            processor = self.get_hf_processor()
-
-        boi_token = processor.boi_token
-
-        num_crops = self.get_num_crops(
-            image_width=image_width,
-            image_height=image_height,
-            processor=processor,
-        )
-
-        if num_crops == 0:
-            image_text = boi_token
-        else:
-            crops_image_tokens = " ".join(boi_token for _ in range(num_crops))
-            image_text = (
-                f"Here is the original image {boi_token} and here are some "
-                f"crops to help you see better {crops_image_tokens}")
-
-        repl_full = image_text.replace(boi_token,
-                                       processor.full_image_sequence)
-
-        tokenizer = processor.tokenizer
-        vocab = tokenizer.get_vocab()
-        image_token_id = vocab[tokenizer.image_token]
-
-        return PromptUpdateDetails.select_token_id(repl_full, image_token_id)
-
-    def get_num_image_tokens(
-        self,
-        *,
-        image_width: int,
-        image_height: int,
-        processor: Optional[Gemma3Processor],
-    ) -> int:
-        if processor is None:
-            processor = self.get_hf_processor()
-
-        num_crops = self.get_num_crops(
-            image_width=image_width,
-            image_height=image_height,
-            processor=processor,
-        )
-        image_seq_len = processor.image_seq_length
-
-        return (num_crops + 1) * image_seq_len
-
-    def get_image_size_with_most_features(self) -> ImageSize:
-        processor = self.get_hf_processor()
-
-        images_kwargs = self._resolve_image_kwargs(
-            processor, {"pan_and_scan_max_num_crops"})
-        max_num_crops = images_kwargs["pan_and_scan_max_num_crops"]
-
-        # Result in the max possible feature size (h:w = max_num_crops:1)
-        return ImageSize(height=50 * max_num_crops, width=50)
+        return {"image": TOKENS_PER_IMAGE, "audio": TOKENS_PER_AUDIO}
 
 
-class Gemma3DummyInputsBuilder(BaseDummyInputsBuilder[Gemma3ProcessingInfo]):
+class Gemma3nDummyInputsBuilder(BaseDummyInputsBuilder[Gemma3nProcessingInfo]):
 
     def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
         num_images = mm_counts.get("image", 0)
+        num_audios = mm_counts.get("audio", 0)
 
         processor = self.info.get_hf_processor()
-        image_token = processor.boi_token
+        image_token = processor.image_token
+        audio_token = processor.audio_token
 
-        return image_token * num_images
+        return image_token * num_images + audio_token * num_audios
 
     def get_dummy_mm_data(
         self,
@@ -245,19 +99,26 @@ class Gemma3DummyInputsBuilder(BaseDummyInputsBuilder[Gemma3ProcessingInfo]):
         mm_counts: Mapping[str, int],
     ) -> MultiModalDataDict:
         num_images = mm_counts.get("image", 0)
-
-        target_width, target_height = \
-            self.info.get_image_size_with_most_features()
+        num_audios = mm_counts.get("audio", 0)
+        processor = self.info.get_hf_processor()
+        feature_extractor: Gemma3nAudioFeatureExtractor = processor.feature_extractor  # noqa: E501
+        audio_len = feature_extractor.max_length
+        image_processor: SiglipImageProcessorFast = processor.image_processor
+        img_width = image_processor.size.get("width", 224)
+        img_height = image_processor.size.get("height", 224)
 
         return {
             "image":
-            self._get_dummy_images(width=target_width,
-                                   height=target_height,
-                                   num_images=num_images)
+            self._get_dummy_images(width=img_width,
+                                   height=img_height,
+                                   num_images=num_images),
+            "audio":
+            self._get_dummy_audios(length=audio_len, num_audios=num_audios)
         }
 
 
-class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
+class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3nProcessingInfo]
+                                ):
 
     def _call_hf_processor(
         self,
@@ -270,27 +131,6 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
             mm_data,
             mm_kwargs,
         )
-
-        # HF processor pops the `num_crops` kwarg, which is needed by vLLM
-        if (images := mm_data.get("images")) is not None:
-            parsed_images = (self._get_data_parser().parse_mm_data({
-                "image":
-                images
-            }).get_items("image", ImageProcessorItems))
-            image_sizes = [
-                parsed_images.get_image_size(i)
-                for i in range(len(parsed_images))
-            ]
-            hf_processor = self.info.get_hf_processor(**mm_kwargs)
-
-            num_crops = [
-                self.info.get_num_crops(image_width=size.width,
-                                        image_height=size.height,
-                                        processor=hf_processor)
-                for size in image_sizes
-            ]
-            processed_outputs["num_crops"] = torch.tensor(num_crops)
-
         return processed_outputs
 
     def _get_mm_fields_config(
@@ -298,12 +138,11 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
         hf_inputs: BatchFeature,
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> Mapping[str, MultiModalFieldConfig]:
-        num_crops = hf_inputs.get("num_crops", torch.empty(0))
 
         return dict(
-            pixel_values=MultiModalFieldConfig.flat_from_sizes(
-                "image", num_crops + 1),
-            num_crops=MultiModalFieldConfig.batched("image"),
+            pixel_values=MultiModalFieldConfig.batched("image"),
+            input_features=MultiModalFieldConfig.batched("audio"),
+            input_features_mask=MultiModalFieldConfig.batched("audio"),
         )
 
     def _get_prompt_updates(
@@ -421,7 +260,8 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
 
 
 class Gemma3nMultimodalEmbedder(nn.Module):
-    """Embeds token ids or soft tokens for multimodal content into language model space."""
+    """Embeds token ids or soft tokens for multimodal content into language
+    model space."""
 
     def __init__(
         self,
@@ -436,7 +276,6 @@ class Gemma3nMultimodalEmbedder(nn.Module):
 
         self.vocab_size = multimodal_config.vocab_size
         self.text_hidden_size = text_config.hidden_size
-
         self.embedding = VocabParallelEmbedding(
             self.vocab_size,
             self.multimodal_hidden_size,
@@ -478,11 +317,10 @@ class Gemma3nMultimodalEmbedder(nn.Module):
 
         Returns:
             A torch.Tensor of embeddings with shape `[batch_size, seq_len, self.config.text_config.hidden_size]`.
-        """
+        """  # noqa: E501
         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError(
-                "You must specify exactly one of input_ids or inputs_embeds"
-            )
+                "You must specify exactly one of input_ids or inputs_embeds")
 
         if inputs_embeds is not None:
             emb_norm = self.soft_embedding_norm(inputs_embeds)
@@ -495,8 +333,8 @@ class Gemma3nMultimodalEmbedder(nn.Module):
 
 
 @MULTIMODAL_REGISTRY.register_processor(Gemma3MultiModalProcessor,
-                                        info=Gemma3ProcessingInfo,
-                                        dummy_inputs=Gemma3DummyInputsBuilder)
+                                        info=Gemma3nProcessingInfo,
+                                        dummy_inputs=Gemma3nDummyInputsBuilder)
 class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal):
     packed_modules_mapping = {
         "qkv_proj": [
@@ -532,8 +370,10 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal):
 
         self.vision_tower = AutoModel.from_config(config=config.vision_config)
         self.audio_tower = AutoModel.from_config(config=config.audio_config)
-        self.embed_vision = Gemma3nMultimodalEmbedder(config.vision_config, config.text_config)
-        self.embed_audio = Gemma3nMultimodalEmbedder(config.audio_config, config.text_config)
+        self.embed_vision = Gemma3nMultimodalEmbedder(config.vision_config,
+                                                      config.text_config)
+        self.embed_audio = Gemma3nMultimodalEmbedder(config.audio_config,
+                                                     config.text_config)
 
         self.language_model = init_vllm_registered_model(
             vllm_config=vllm_config,
@@ -553,9 +393,9 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal):
         assert self.vision_tower is not None
 
         pixel_values = image_input["pixel_values"]
-        vision_outputs = self.vision_tower(
-            pixel_values=pixel_values, do_pooling=False, return_dict=True
-        ).last_hidden_state
+        vision_outputs = self.vision_tower(pixel_values=pixel_values,
+                                           do_pooling=False,
+                                           return_dict=True).last_hidden_state
         vision_outputs = vision_outputs.reshape(
             vision_outputs.shape[0],
             self.config.vision_config.hidden_size,
@@ -566,14 +406,16 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal):
         return self.embed_vision(inputs_embeds=vision_outputs)
 
     def _process_audio_input(
-        self, audio_input: Gemma3nAudioInputs,
+        self,
+        audio_input: Gemma3nAudioInputs,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         assert self.audio_tower is not None
 
         input_features = audio_input["input_features"]
         input_features_mask = audio_input["input_features_mask"]
-        audio_outputs, audio_mask = self.audio_tower(input_features, input_features_mask)
+        audio_outputs, audio_mask = self.audio_tower(input_features,
+                                                     input_features_mask)
         return self.embed_audio(inputs_embeds=audio_outputs), audio_mask
-    
+
     def get_language_model(self) -> torch.nn.Module:
         return self.language_model
diff --git a/vllm/model_executor/models/qwen_vl.py b/vllm/model_executor/models/qwen_vl.py
index fc29785af9..fe20de10da 100644
--- a/vllm/model_executor/models/qwen_vl.py
+++ b/vllm/model_executor/models/qwen_vl.py
@@ -617,7 +617,8 @@ class QwenVLMultiModalProcessor(BaseMultiModalProcessor[QwenVLProcessingInfo]):
     ) -> Mapping[str, MultiModalFieldConfig]:
         return dict(
             pixel_values=MultiModalFieldConfig.batched("image"),
-            image_embeds=MultiModalFieldConfig.batched("image"),
+            input_features=MultiModalFieldConfig.batched("audio"),
+            input_features_mask=MultiModalFieldConfig.batched("audio"),
         )
 
     def _get_prompt_updates(