Signed-off-by: ShriKode <shrikode@gmail.com>
Author: ShriKode
Date:   2025-06-28 22:21:17 +00:00
parent bfd63b1b10
commit b801bf30d7
2 changed files with 66 additions and 223 deletions

View File

@@ -1,59 +1,58 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-import math
 from collections.abc import Iterable, Mapping, Sequence
-from typing import Any, Literal, Optional, TypedDict
+from typing import Any, Optional, TypedDict, Union
 
 import torch
 from torch import nn
-from transformers import BatchFeature, Gemma3nConfig, Gemma3nProcessor
-from transformers.models.gemma3n.processing_gemma3n import Gemma3nProcessorKwargs
-from transformers import AutoModel
+from transformers import AutoModel, BatchFeature
+from transformers.models.gemma3n import (Gemma3nAudioConfig,
+                                         Gemma3nAudioFeatureExtractor,
+                                         Gemma3nConfig, Gemma3nProcessor,
+                                         Gemma3nTextConfig,
+                                         Gemma3nVisionConfig)
+from transformers.models.siglip import SiglipImageProcessorFast
 
-import vllm.envs as envs
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.linear import RowParallelLinear
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding)
-from vllm.model_executor.layers.linear import (ColumnParallelLinear,
-                                               MergedColumnParallelLinear,
-                                               QKVParallelLinear,
-                                               ReplicatedLinear,
-                                               RowParallelLinear)
 from vllm.model_executor.models.module_mapping import MultiModelKeys
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
                                     MultiModalKwargs)
-from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
-                                   MultiModalDataItems)
+from vllm.multimodal.parse import ImageProcessorItems, MultiModalDataItems
 # yapf: disable
 from vllm.multimodal.processing import (BaseMultiModalProcessor,
                                         BaseProcessingInfo, BoundPromptUpdate,
                                         PlaceholderFeaturesInfo,
                                         PromptReplacement, PromptTargetMatch,
-                                        PromptUpdate, PromptUpdateDetails,
-                                        find_mm_placeholders,
+                                        PromptUpdate, find_mm_placeholders,
                                         replace_token_matches)
 # yapf: enable
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
 
-from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
-                         SupportsMultiModal, SupportsPP)
-from .siglip import SiglipVisionModel
-from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
+from .interfaces import MultiModalEmbeddings, SupportsMultiModal
+from .utils import (AutoWeightsLoader, WeightsMapper,
                     init_vllm_registered_model, maybe_prefix,
                     merge_multimodal_embeddings)
 
 logger = init_logger(__name__)
 
+# This should be based on model config but we hardcode them for now.
+TOKENS_PER_IMAGE = 256
+TOKENS_PER_AUDIO = 188
+
 
 class Gemma3nImagePixelInputs(TypedDict):
     pixel_values: torch.Tensor
     """Shape: `(batch_size * num_images, num_channels, height, width)`"""
 
 
 class Gemma3nAudioInputs(TypedDict):
     input_features: torch.Tensor
     """Shape: `(batch_size * num_audio, seq_length, num_features)`"""
@@ -64,7 +63,7 @@ class Gemma3nAudioInputs(TypedDict):
 Gemma3nImageInputs = Gemma3nImagePixelInputs
 
 
-class Gemma3ProcessingInfo(BaseProcessingInfo):
+class Gemma3nProcessingInfo(BaseProcessingInfo):
 
     def get_hf_config(self):
         return self.ctx.get_hf_config(Gemma3nConfig)
@@ -73,171 +72,26 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
         return self.ctx.get_hf_processor(Gemma3nProcessor, **kwargs)
 
     def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
-        return {"image": None}
-
-    def _resolve_image_kwargs(
-        self,
-        processor: Gemma3Processor,
-        keys: set[str],
-    ) -> dict[str, Any]:
-        image_processor = processor.image_processor
-        kwargs = processor._merge_kwargs(
-            Gemma3ProcessorKwargs,
-            tokenizer_init_kwargs=processor.tokenizer.init_kwargs,
-        )
-        images_kwargs = kwargs["images_kwargs"]
-
-        def _resolve_kw(key: str):
-            val = getattr(image_processor, key)
-            if val is None:
-                val = images_kwargs[key]
-            return val
-
-        return {k: _resolve_kw(k) for k in keys}
-
-    def get_num_crops(
-        self,
-        *,
-        image_width: int,
-        image_height: int,
-        processor: Optional[Gemma3Processor],
-    ) -> int:
-        if processor is None:
-            processor = self.get_hf_processor()
-        images_kwargs = self._resolve_image_kwargs(
-            processor, {
-                "do_pan_and_scan", "pan_and_scan_min_crop_size",
-                "pan_and_scan_max_num_crops",
-                "pan_and_scan_min_ratio_to_activate"
-            })
-        do_pan_and_scan = images_kwargs["do_pan_and_scan"]
-        pan_and_scan_min_crop_size = images_kwargs[
-            "pan_and_scan_min_crop_size"]
-        pan_and_scan_max_num_crops = images_kwargs[
-            "pan_and_scan_max_num_crops"]
-        pan_and_scan_min_ratio_to_activate = images_kwargs[
-            "pan_and_scan_min_ratio_to_activate"]
-        if not do_pan_and_scan:
-            return 0
-        if envs.VLLM_USE_V1:
-            logger.warning_once(
-                "`do_pan_and_scan=True` has suboptimal results on V1 "
-                "because of the simplified attention pattern being used.")
-        # Based on Gemma3ImageProcessor.pan_and_scan
-        if image_width >= image_height:
-            if image_width / image_height < pan_and_scan_min_ratio_to_activate:
-                return 0
-            num_crops_w = min(
-                int(math.floor(image_width / pan_and_scan_min_crop_size)),
-                int(math.floor(image_width / image_height + 0.5)),
-            )
-            num_crops_w = max(2, num_crops_w)
-            num_crops_w = min(pan_and_scan_max_num_crops, num_crops_w)
-            num_crops_h = 1
-        else:
-            if image_height / image_width < pan_and_scan_min_ratio_to_activate:
-                return 0
-            num_crops_h = min(
-                int(math.floor(image_height / pan_and_scan_min_crop_size)),
-                int(math.floor(image_height / image_width + 0.5)),
-            )
-            num_crops_h = max(2, num_crops_h)
-            num_crops_h = min(pan_and_scan_max_num_crops, num_crops_h)
-            num_crops_w = 1
-        crop_size_w = int(math.ceil(image_width / num_crops_w))
-        crop_size_h = int(math.ceil(image_height / num_crops_h))
-        if min(crop_size_w, crop_size_h) < pan_and_scan_min_crop_size:
-            return 0
-        return num_crops_w * num_crops_h
-
-    def get_image_repl(
-        self,
-        *,
-        image_width: int,
-        image_height: int,
-        processor: Optional[Gemma3Processor],
-    ) -> PromptUpdateDetails[str]:
-        if processor is None:
-            processor = self.get_hf_processor()
-        boi_token = processor.boi_token
-        num_crops = self.get_num_crops(
-            image_width=image_width,
-            image_height=image_height,
-            processor=processor,
-        )
-        if num_crops == 0:
-            image_text = boi_token
-        else:
-            crops_image_tokens = " ".join(boi_token for _ in range(num_crops))
-            image_text = (
-                f"Here is the original image {boi_token} and here are some "
-                f"crops to help you see better {crops_image_tokens}")
-        repl_full = image_text.replace(boi_token,
-                                       processor.full_image_sequence)
-        tokenizer = processor.tokenizer
-        vocab = tokenizer.get_vocab()
-        image_token_id = vocab[tokenizer.image_token]
-        return PromptUpdateDetails.select_token_id(repl_full, image_token_id)
-
-    def get_num_image_tokens(
-        self,
-        *,
-        image_width: int,
-        image_height: int,
-        processor: Optional[Gemma3Processor],
-    ) -> int:
-        if processor is None:
-            processor = self.get_hf_processor()
-        num_crops = self.get_num_crops(
-            image_width=image_width,
-            image_height=image_height,
-            processor=processor,
-        )
-        image_seq_len = processor.image_seq_length
-        return (num_crops + 1) * image_seq_len
-
-    def get_image_size_with_most_features(self) -> ImageSize:
-        processor = self.get_hf_processor()
-        images_kwargs = self._resolve_image_kwargs(
-            processor, {"pan_and_scan_max_num_crops"})
-        max_num_crops = images_kwargs["pan_and_scan_max_num_crops"]
-        # Result in the max possible feature size (h:w = max_num_crops:1)
-        return ImageSize(height=50 * max_num_crops, width=50)
+        return {"image": None, "audio": None}
+
+    def get_max_tokens_per_item(
+            self, seq_len: int,
+            mm_counts: Mapping[str, int]) -> Optional[Mapping[str, int]]:
+        return {"image": TOKENS_PER_IMAGE, "audio": TOKENS_PER_AUDIO}
 
 
-class Gemma3DummyInputsBuilder(BaseDummyInputsBuilder[Gemma3ProcessingInfo]):
+class Gemma3nDummyInputsBuilder(BaseDummyInputsBuilder[Gemma3nProcessingInfo]):
 
     def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
         num_images = mm_counts.get("image", 0)
+        num_audios = mm_counts.get("audio", 0)
 
         processor = self.info.get_hf_processor()
-        image_token = processor.boi_token
+        image_token = processor.image_token
+        audio_token = processor.audio_token
 
-        return image_token * num_images
+        return image_token * num_images + audio_token * num_audios
 
     def get_dummy_mm_data(
         self,
@@ -245,19 +99,26 @@ class Gemma3DummyInputsBuilder(BaseDummyInputsBuilder[Gemma3ProcessingInfo]):
         mm_counts: Mapping[str, int],
     ) -> MultiModalDataDict:
         num_images = mm_counts.get("image", 0)
+        num_audios = mm_counts.get("audio", 0)
 
-        target_width, target_height = \
-            self.info.get_image_size_with_most_features()
+        processor = self.info.get_hf_processor()
+        feature_extractor: Gemma3nAudioFeatureExtractor = processor.feature_extractor  # noqa: E501
+        audio_len = feature_extractor.max_length
+
+        image_processor: SiglipImageProcessorFast = processor.image_processor
+        img_width = image_processor.size.get("width", 224)
+        img_height = image_processor.size.get("height", 224)
 
         return {
             "image":
-            self._get_dummy_images(width=target_width,
-                                   height=target_height,
-                                   num_images=num_images)
+            self._get_dummy_images(width=img_width,
+                                   height=img_height,
+                                   num_images=num_images),
+            "audio":
+            self._get_dummy_audios(length=audio_len, num_audios=num_audios)
         }
 
 
-class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
+class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3nProcessingInfo]
+                                ):
 
     def _call_hf_processor(
         self,
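The dummy-data path above now sizes images from the SigLIP image processor's `size` dict and audio from the feature extractor's `max_length`. A self-contained sketch of equivalent profiling inputs, using hypothetical stand-ins for vLLM's `_get_dummy_images`/`_get_dummy_audios` helpers and assumed default sizes:

```python
import numpy as np
from PIL import Image

# Hypothetical stand-ins for the builder helpers used above.
def make_dummy_images(width: int, height: int, num_images: int) -> list[Image.Image]:
    # Solid RGB images are enough for memory profiling.
    return [Image.new("RGB", (width, height)) for _ in range(num_images)]

def make_dummy_audios(length: int, num_audios: int) -> list[np.ndarray]:
    # Zero-filled waveforms with the extractor's max_length samples.
    return [np.zeros(length, dtype=np.float32) for _ in range(num_audios)]

mm_data = {
    "image": make_dummy_images(224, 224, num_images=1),  # assumed 224x224
    "audio": make_dummy_audios(480_000, num_audios=1),   # assumed max_length
}
```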
@@ -270,27 +131,6 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
             mm_data,
             mm_kwargs,
         )
-        # HF processor pops the `num_crops` kwarg, which is needed by vLLM
-        if (images := mm_data.get("images")) is not None:
-            parsed_images = (self._get_data_parser().parse_mm_data({
-                "image":
-                images
-            }).get_items("image", ImageProcessorItems))
-            image_sizes = [
-                parsed_images.get_image_size(i)
-                for i in range(len(parsed_images))
-            ]
-            hf_processor = self.info.get_hf_processor(**mm_kwargs)
-            num_crops = [
-                self.info.get_num_crops(image_width=size.width,
-                                        image_height=size.height,
-                                        processor=hf_processor)
-                for size in image_sizes
-            ]
-            processed_outputs["num_crops"] = torch.tensor(num_crops)
         return processed_outputs
 
     def _get_mm_fields_config(
@@ -298,12 +138,11 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
         hf_inputs: BatchFeature,
         hf_processor_mm_kwargs: Mapping[str, object],
     ) -> Mapping[str, MultiModalFieldConfig]:
-        num_crops = hf_inputs.get("num_crops", torch.empty(0))
         return dict(
-            pixel_values=MultiModalFieldConfig.flat_from_sizes(
-                "image", num_crops + 1),
-            num_crops=MultiModalFieldConfig.batched("image"),
+            pixel_values=MultiModalFieldConfig.batched("image"),
+            input_features=MultiModalFieldConfig.batched("audio"),
+            input_features_mask=MultiModalFieldConfig.batched("audio"),
         )
 
     def _get_prompt_updates(
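The field-config change above replaces the per-crop flat layout with plain batched fields, since the fixed-resolution Gemma 3n pipeline no longer produces a variable number of crops per image. A toy illustration of the two layouts (made-up tensor sizes, not vLLM internals):

```python
import torch

# batched("image"): one leading entry per item -> (num_items, C, H, W).
pixel_values_batched = torch.zeros(2, 3, 224, 224)

# flat_from_sizes("image", num_crops + 1): variable-size items are
# concatenated along dim 0 and split back using per-item sizes.
num_crops = torch.tensor([0, 3])  # crops per image in the old scheme
sizes = num_crops + 1             # base image + crops
pixel_values_flat = torch.zeros(int(sizes.sum()), 3, 224, 224)  # (5, 3, 224, 224)
```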
@@ -421,7 +260,8 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
 
 class Gemma3nMultimodalEmbedder(nn.Module):
-    """Embeds token ids or soft tokens for multimodal content into language model space."""
+    """Embeds token ids or soft tokens for multimodal content into language
+    model space."""
 
     def __init__(
         self,
@@ -436,7 +276,6 @@ class Gemma3nMultimodalEmbedder(nn.Module):
         self.vocab_size = multimodal_config.vocab_size
         self.text_hidden_size = text_config.hidden_size
-
         self.embedding = VocabParallelEmbedding(
             self.vocab_size,
             self.multimodal_hidden_size,
@@ -478,11 +317,10 @@ class Gemma3nMultimodalEmbedder(nn.Module):
         Returns:
             A torch.Tensor of embeddings with shape `[batch_size, seq_len, self.config.text_config.hidden_size]`.
-        """
+        """  # noqa: E501
         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError(
-                "You must specify exactly one of input_ids or inputs_embeds"
-            )
+                "You must specify exactly one of input_ids or inputs_embeds")
 
         if inputs_embeds is not None:
             emb_norm = self.soft_embedding_norm(inputs_embeds)
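The XOR guard kept above is easy to misread; a quick sanity check showing it raises exactly when both or neither argument is supplied:

```python
def violates(input_ids, inputs_embeds) -> bool:
    # Mirrors the guard above: True means "invalid combination, raise".
    return (input_ids is None) ^ (inputs_embeds is not None)

assert violates(None, None)         # neither given -> error
assert violates("ids", "embeds")    # both given -> error
assert not violates("ids", None)    # input_ids only -> ok
assert not violates(None, "embeds") # inputs_embeds only -> ok
```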
@@ -495,8 +333,8 @@ class Gemma3nMultimodalEmbedder(nn.Module):
 @MULTIMODAL_REGISTRY.register_processor(Gemma3MultiModalProcessor,
-                                        info=Gemma3ProcessingInfo,
-                                        dummy_inputs=Gemma3DummyInputsBuilder)
+                                        info=Gemma3nProcessingInfo,
+                                        dummy_inputs=Gemma3nDummyInputsBuilder)
 class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal):
     packed_modules_mapping = {
         "qkv_proj": [
@@ -532,8 +370,10 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal):
         self.vision_tower = AutoModel.from_config(config=config.vision_config)
         self.audio_tower = AutoModel.from_config(config=config.audio_config)
-        self.embed_vision = Gemma3nMultimodalEmbedder(config.vision_config, config.text_config)
-        self.embed_audio = Gemma3nMultimodalEmbedder(config.audio_config, config.text_config)
+        self.embed_vision = Gemma3nMultimodalEmbedder(config.vision_config,
+                                                      config.text_config)
+        self.embed_audio = Gemma3nMultimodalEmbedder(config.audio_config,
+                                                     config.text_config)
 
         self.language_model = init_vllm_registered_model(
             vllm_config=vllm_config,
@@ -553,9 +393,9 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal):
         assert self.vision_tower is not None
         pixel_values = image_input["pixel_values"]
-        vision_outputs = self.vision_tower(
-            pixel_values=pixel_values, do_pooling=False, return_dict=True
-        ).last_hidden_state
+        vision_outputs = self.vision_tower(pixel_values=pixel_values,
+                                           do_pooling=False,
+                                           return_dict=True).last_hidden_state
         vision_outputs = vision_outputs.reshape(
             vision_outputs.shape[0],
             self.config.vision_config.hidden_size,
@@ -566,14 +406,16 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal):
         return self.embed_vision(inputs_embeds=vision_outputs)
 
     def _process_audio_input(
-        self, audio_input: Gemma3nAudioInputs,
+        self,
+        audio_input: Gemma3nAudioInputs,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         assert self.audio_tower is not None
         input_features = audio_input["input_features"]
         input_features_mask = audio_input["input_features_mask"]
-        audio_outputs, audio_mask = self.audio_tower(input_features, input_features_mask)
+        audio_outputs, audio_mask = self.audio_tower(input_features,
+                                                     input_features_mask)
         return self.embed_audio(inputs_embeds=audio_outputs), audio_mask
 
     def get_language_model(self) -> torch.nn.Module:
         return self.language_model
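For orientation, a minimal stand-in for the tower-to-embedder hand-off in `_process_audio_input`, with hypothetical hidden sizes (the real values come from the audio and text configs, and the real `Gemma3nMultimodalEmbedder` also normalizes with RMSNorm before projecting):

```python
import torch
from torch import nn

num_audio, audio_tokens = 2, 188        # hypothetical batch and token count
audio_hidden, text_hidden = 1536, 2048  # hypothetical hidden sizes

# Stand-in for the embedder's projection into language-model space.
projection = nn.Linear(audio_hidden, text_hidden, bias=False)

audio_outputs = torch.randn(num_audio, audio_tokens, audio_hidden)
text_space = projection(audio_outputs)
assert text_space.shape == (num_audio, audio_tokens, text_hidden)
```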

View File

@@ -617,7 +617,8 @@ class QwenVLMultiModalProcessor(BaseMultiModalProcessor[QwenVLProcessingInfo]):
     ) -> Mapping[str, MultiModalFieldConfig]:
         return dict(
             pixel_values=MultiModalFieldConfig.batched("image"),
-            image_embeds=MultiModalFieldConfig.batched("image"),
+            input_features=MultiModalFieldConfig.batched("audio"),
+            input_features_mask=MultiModalFieldConfig.batched("audio"),
         )
 
     def _get_prompt_updates(