Signed-off-by: ShriKode <shrikode@gmail.com>
ShriKode 2025-06-28 22:21:17 +00:00
parent bfd63b1b10
commit b801bf30d7
2 changed files with 66 additions and 223 deletions

View File

@@ -1,59 +1,58 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
from collections.abc import Iterable, Mapping, Sequence
from typing import Any, Literal, Optional, TypedDict
from typing import Any, Optional, TypedDict, Union
import torch
from torch import nn
from transformers import BatchFeature, Gemma3nConfig, Gemma3nProcessor
from transformers.models.gemma3n.processing_gemma3n import Gemma3nProcessorKwargs
from transformers import AutoModel
from transformers import AutoModel, BatchFeature
from transformers.models.gemma3n import (Gemma3nAudioConfig,
Gemma3nAudioFeatureExtractor,
Gemma3nConfig, Gemma3nProcessor,
Gemma3nTextConfig,
Gemma3nVisionConfig)
from transformers.models.siglip import SiglipImageProcessorFast
import vllm.envs as envs
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import RowParallelLinear
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
MultiModalKwargs)
from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
MultiModalDataItems)
from vllm.multimodal.parse import ImageProcessorItems, MultiModalDataItems
# yapf: disable
from vllm.multimodal.processing import (BaseMultiModalProcessor,
BaseProcessingInfo, BoundPromptUpdate,
PlaceholderFeaturesInfo,
PromptReplacement, PromptTargetMatch,
PromptUpdate, PromptUpdateDetails,
find_mm_placeholders,
PromptUpdate, find_mm_placeholders,
replace_token_matches)
# yapf: enable
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
SupportsMultiModal, SupportsPP)
from .siglip import SiglipVisionModel
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
from .interfaces import MultiModalEmbeddings, SupportsMultiModal
from .utils import (AutoWeightsLoader, WeightsMapper,
init_vllm_registered_model, maybe_prefix,
merge_multimodal_embeddings)
logger = init_logger(__name__)
# These should be derived from the model config, but we hardcode them for now.
TOKENS_PER_IMAGE = 256
TOKENS_PER_AUDIO = 188
class Gemma3nImagePixelInputs(TypedDict):
pixel_values: torch.Tensor
"""Shape: `(batch_size * num_images, num_channels, height, width)`"""
class Gemma3nAudioInputs(TypedDict):
input_features: torch.Tensor
"""Shape: `(batch_size * num_audio, seq_length, num_features)`"""
@@ -64,7 +63,7 @@ class Gemma3nAudioInputs(TypedDict):
Gemma3nImageInputs = Gemma3nImagePixelInputs
class Gemma3ProcessingInfo(BaseProcessingInfo):
class Gemma3nProcessingInfo(BaseProcessingInfo):
def get_hf_config(self):
return self.ctx.get_hf_config(Gemma3nConfig)
@@ -73,171 +72,26 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
return self.ctx.get_hf_processor(Gemma3nProcessor, **kwargs)
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"image": None}
return {"image": None, "audio": None}
def _resolve_image_kwargs(
self,
processor: Gemma3Processor,
keys: set[str],
) -> dict[str, Any]:
image_processor = processor.image_processor
kwargs = processor._merge_kwargs(
Gemma3ProcessorKwargs,
tokenizer_init_kwargs=processor.tokenizer.init_kwargs,
)
def get_max_tokens_per_item(
self, seq_len: int,
mm_counts: Mapping[str, int]) -> Optional[Mapping[str, int]]:
images_kwargs = kwargs["images_kwargs"]
def _resolve_kw(key: str):
val = getattr(image_processor, key)
if val is None:
val = images_kwargs[key]
return val
return {k: _resolve_kw(k) for k in keys}
def get_num_crops(
self,
*,
image_width: int,
image_height: int,
processor: Optional[Gemma3Processor],
) -> int:
if processor is None:
processor = self.get_hf_processor()
images_kwargs = self._resolve_image_kwargs(
processor, {
"do_pan_and_scan", "pan_and_scan_min_crop_size",
"pan_and_scan_max_num_crops",
"pan_and_scan_min_ratio_to_activate"
})
do_pan_and_scan = images_kwargs["do_pan_and_scan"]
pan_and_scan_min_crop_size = images_kwargs[
"pan_and_scan_min_crop_size"]
pan_and_scan_max_num_crops = images_kwargs[
"pan_and_scan_max_num_crops"]
pan_and_scan_min_ratio_to_activate = images_kwargs[
"pan_and_scan_min_ratio_to_activate"]
if not do_pan_and_scan:
return 0
if envs.VLLM_USE_V1:
logger.warning_once(
"`do_pan_and_scan=True` has suboptimal results on V1 "
"because of the simplified attention pattern being used.")
# Based on Gemma3ImageProcessor.pan_and_scan
if image_width >= image_height:
if image_width / image_height < pan_and_scan_min_ratio_to_activate:
return 0
num_crops_w = min(
int(math.floor(image_width / pan_and_scan_min_crop_size)),
int(math.floor(image_width / image_height + 0.5)),
)
num_crops_w = max(2, num_crops_w)
num_crops_w = min(pan_and_scan_max_num_crops, num_crops_w)
num_crops_h = 1
else:
if image_height / image_width < pan_and_scan_min_ratio_to_activate:
return 0
num_crops_h = min(
int(math.floor(image_height / pan_and_scan_min_crop_size)),
int(math.floor(image_height / image_width + 0.5)),
)
num_crops_h = max(2, num_crops_h)
num_crops_h = min(pan_and_scan_max_num_crops, num_crops_h)
num_crops_w = 1
crop_size_w = int(math.ceil(image_width / num_crops_w))
crop_size_h = int(math.ceil(image_height / num_crops_h))
if min(crop_size_w, crop_size_h) < pan_and_scan_min_crop_size:
return 0
return num_crops_w * num_crops_h
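For intuition, here is the removed pan-and-scan arithmetic run on concrete numbers. The image size and thresholds below are illustrative, not the processor's real defaults:

```python
import math

# Hypothetical inputs; values chosen only to exercise the landscape branch.
image_width, image_height = 1000, 400
pan_and_scan_min_crop_size = 256
pan_and_scan_max_num_crops = 4
pan_and_scan_min_ratio_to_activate = 1.2

# width >= height and 1000/400 = 2.5 >= 1.2, so pan-and-scan activates.
num_crops_w = min(
    math.floor(image_width / pan_and_scan_min_crop_size),  # floor(3.91) = 3
    math.floor(image_width / image_height + 0.5),          # floor(3.0)  = 3
)
num_crops_w = max(2, num_crops_w)                           # 3
num_crops_w = min(pan_and_scan_max_num_crops, num_crops_w)  # 3
num_crops_h = 1

crop_size_w = math.ceil(image_width / num_crops_w)   # 334 >= 256, crops kept
crop_size_h = math.ceil(image_height / num_crops_h)  # 400
assert min(crop_size_w, crop_size_h) >= pan_and_scan_min_crop_size
print(num_crops_w * num_crops_h)  # 3
```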
def get_image_repl(
self,
*,
image_width: int,
image_height: int,
processor: Optional[Gemma3Processor],
) -> PromptUpdateDetails[str]:
if processor is None:
processor = self.get_hf_processor()
boi_token = processor.boi_token
num_crops = self.get_num_crops(
image_width=image_width,
image_height=image_height,
processor=processor,
)
if num_crops == 0:
image_text = boi_token
else:
crops_image_tokens = " ".join(boi_token for _ in range(num_crops))
image_text = (
f"Here is the original image {boi_token} and here are some "
f"crops to help you see better {crops_image_tokens}")
repl_full = image_text.replace(boi_token,
processor.full_image_sequence)
tokenizer = processor.tokenizer
vocab = tokenizer.get_vocab()
image_token_id = vocab[tokenizer.image_token]
return PromptUpdateDetails.select_token_id(repl_full, image_token_id)
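Concretely, for `num_crops = 2` the replacement above yields a prompt of the following shape; the token strings here are placeholders rather than the real Gemma3n specials:

```python
boi_token = "<boi>"              # placeholder for the begin-of-image token
full_image_sequence = "<imgseq>" # stands in for the expanded image-token run
num_crops = 2

if num_crops == 0:
    image_text = boi_token
else:
    crops_image_tokens = " ".join(boi_token for _ in range(num_crops))
    image_text = (f"Here is the original image {boi_token} and here are some "
                  f"crops to help you see better {crops_image_tokens}")

print(image_text.replace(boi_token, full_image_sequence))
# Here is the original image <imgseq> and here are some crops to help you
# see better <imgseq> <imgseq>
```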
def get_num_image_tokens(
self,
*,
image_width: int,
image_height: int,
processor: Optional[Gemma3Processor],
) -> int:
if processor is None:
processor = self.get_hf_processor()
num_crops = self.get_num_crops(
image_width=image_width,
image_height=image_height,
processor=processor,
)
image_seq_len = processor.image_seq_length
return (num_crops + 1) * image_seq_len
def get_image_size_with_most_features(self) -> ImageSize:
processor = self.get_hf_processor()
images_kwargs = self._resolve_image_kwargs(
processor, {"pan_and_scan_max_num_crops"})
max_num_crops = images_kwargs["pan_and_scan_max_num_crops"]
# Results in the max possible feature size (h:w = max_num_crops:1)
return ImageSize(height=50 * max_num_crops, width=50)
return {"image": TOKENS_PER_IMAGE, "audio": TOKENS_PER_AUDIO}
class Gemma3DummyInputsBuilder(BaseDummyInputsBuilder[Gemma3ProcessingInfo]):
class Gemma3nDummyInputsBuilder(BaseDummyInputsBuilder[Gemma3nProcessingInfo]):
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
num_images = mm_counts.get("image", 0)
num_audios = mm_counts.get("audio", 0)
processor = self.info.get_hf_processor()
image_token = processor.boi_token
image_token = processor.image_token
audio_token = processor.audio_token
return image_token * num_images
return image_token * num_images + audio_token * num_audios
def get_dummy_mm_data(
self,
@@ -245,19 +99,26 @@ class Gemma3DummyInputsBuilder(BaseDummyInputsBuilder[Gemma3ProcessingInfo]):
mm_counts: Mapping[str, int],
) -> MultiModalDataDict:
num_images = mm_counts.get("image", 0)
target_width, target_height = \
self.info.get_image_size_with_most_features()
num_audios = mm_counts.get("audio", 0)
processor = self.info.get_hf_processor()
feature_extractor: Gemma3nAudioFeatureExtractor = processor.feature_extractor # noqa: E501
audio_len = feature_extractor.max_length
image_processor: SiglipImageProcessorFast = processor.image_processor
img_width = image_processor.size.get("width", 224)
img_height = image_processor.size.get("height", 224)
return {
"image":
self._get_dummy_images(width=target_width,
height=target_height,
num_images=num_images)
self._get_dummy_images(width=img_width,
height=img_height,
num_images=num_images),
"audio":
self._get_dummy_audios(length=audio_len, num_audios=num_audios)
}
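For reference, a self-contained approximation of what this dummy data boils down to. `_get_dummy_images` and `_get_dummy_audios` are base-class helpers; the exact payload types below (PIL images, float32 waveforms) are an assumption from how they are used here:

```python
import numpy as np
from PIL import Image

def dummy_images(width: int, height: int, num_images: int) -> list[Image.Image]:
    # Blank RGB images at the image processor's native resolution.
    return [Image.new("RGB", (width, height)) for _ in range(num_images)]

def dummy_audios(length: int, num_audios: int) -> list[np.ndarray]:
    # Silent waveforms padded to the feature extractor's max_length.
    return [np.zeros(length, dtype=np.float32) for _ in range(num_audios)]

mm_data = {
    "image": dummy_images(224, 224, num_images=1),
    "audio": dummy_audios(480_000, num_audios=1),  # length is illustrative
}
```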
class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3nProcessingInfo]
):
def _call_hf_processor(
self,
@@ -270,27 +131,6 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
mm_data,
mm_kwargs,
)
# HF processor pops the `num_crops` kwarg, which is needed by vLLM
if (images := mm_data.get("images")) is not None:
parsed_images = (self._get_data_parser().parse_mm_data({
"image":
images
}).get_items("image", ImageProcessorItems))
image_sizes = [
parsed_images.get_image_size(i)
for i in range(len(parsed_images))
]
hf_processor = self.info.get_hf_processor(**mm_kwargs)
num_crops = [
self.info.get_num_crops(image_width=size.width,
image_height=size.height,
processor=hf_processor)
for size in image_sizes
]
processed_outputs["num_crops"] = torch.tensor(num_crops)
return processed_outputs
def _get_mm_fields_config(
@@ -298,12 +138,11 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
num_crops = hf_inputs.get("num_crops", torch.empty(0))
return dict(
pixel_values=MultiModalFieldConfig.flat_from_sizes(
"image", num_crops + 1),
num_crops=MultiModalFieldConfig.batched("image"),
pixel_values=MultiModalFieldConfig.batched("image"),
input_features=MultiModalFieldConfig.batched("audio"),
input_features_mask=MultiModalFieldConfig.batched("audio"),
)
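The move from `flat_from_sizes` to `batched` follows from dropping pan-and-scan: each image now contributes exactly one entry along the leading dimension, so no per-item size bookkeeping is needed. A plain-torch sketch of the difference between the two layouts (independent of vLLM's actual field-config machinery):

```python
import torch

# "batched": the leading dimension indexes items one-to-one.
pixel_values = torch.randn(3, 3, 224, 224)          # 3 images, 1 entry each
per_image = [pixel_values[i] for i in range(3)]

# "flat_from_sizes": variable-size items are concatenated along dim 0 and
# must be sliced apart using per-item sizes. With pan-and-scan, image i
# contributed (num_crops[i] + 1) entries.
num_crops = torch.tensor([0, 3, 1])
sizes = (num_crops + 1).tolist()                    # 1, 4, 2 entries
flat = torch.randn(sum(sizes), 3, 224, 224)
offsets = [0]
for s in sizes:
    offsets.append(offsets[-1] + s)
per_image_flat = [flat[offsets[i]:offsets[i + 1]] for i in range(3)]
assert [t.shape[0] for t in per_image_flat] == [1, 4, 2]
```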
def _get_prompt_updates(
@@ -421,7 +260,8 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
class Gemma3nMultimodalEmbedder(nn.Module):
"""Embeds token ids or soft tokens for multimodal content into language model space."""
"""Embeds token ids or soft tokens for multimodal content into language
model space."""
def __init__(
self,
@@ -436,7 +276,6 @@ class Gemma3nMultimodalEmbedder(nn.Module):
self.vocab_size = multimodal_config.vocab_size
self.text_hidden_size = text_config.hidden_size
self.embedding = VocabParallelEmbedding(
self.vocab_size,
self.multimodal_hidden_size,
@@ -478,11 +317,10 @@ class Gemma3nMultimodalEmbedder(nn.Module):
Returns:
A torch.Tensor of embeddings with shape `[batch_size, seq_len, self.config.text_config.hidden_size]`.
"""
""" # noqa: E501
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError(
"You must specify exactly one of input_ids or inputs_embeds"
)
"You must specify exactly one of input_ids or inputs_embeds")
if inputs_embeds is not None:
emb_norm = self.soft_embedding_norm(inputs_embeds)
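The surrounding `forward` is a common two-path embedder: hard token ids go through a vocab embedding, while soft (continuous) features are RMS-normalized and projected into the text hidden size. A minimal standalone sketch of the pattern; the `embedding_projection` layer and its placement are assumptions, since the hunk ends before them (`nn.RMSNorm` needs torch >= 2.4):

```python
import torch
from torch import nn

class ToyMultimodalEmbedder(nn.Module):
    def __init__(self, vocab_size=1000, mm_hidden=64, text_hidden=128):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, mm_hidden)
        self.soft_embedding_norm = nn.RMSNorm(mm_hidden)
        self.embedding_projection = nn.Linear(mm_hidden, text_hidden, bias=False)

    def forward(self, input_ids=None, inputs_embeds=None):
        # Same XOR validation as above: exactly one input must be given.
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError(
                "You must specify exactly one of input_ids or inputs_embeds")
        if inputs_embeds is not None:
            emb = self.soft_embedding_norm(inputs_embeds)  # soft-token path
        else:
            emb = self.embedding(input_ids)                # hard-token path
        return self.embedding_projection(emb)

out = ToyMultimodalEmbedder()(inputs_embeds=torch.randn(1, 4, 64))
assert out.shape == (1, 4, 128)
```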
@@ -495,8 +333,8 @@ class Gemma3nMultimodalEmbedder(nn.Module):
@MULTIMODAL_REGISTRY.register_processor(Gemma3MultiModalProcessor,
info=Gemma3ProcessingInfo,
dummy_inputs=Gemma3DummyInputsBuilder)
info=Gemma3nProcessingInfo,
dummy_inputs=Gemma3nDummyInputsBuilder)
class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal):
packed_modules_mapping = {
"qkv_proj": [
@@ -532,8 +370,10 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal):
self.vision_tower = AutoModel.from_config(config=config.vision_config)
self.audio_tower = AutoModel.from_config(config=config.audio_config)
self.embed_vision = Gemma3nMultimodalEmbedder(config.vision_config, config.text_config)
self.embed_audio = Gemma3nMultimodalEmbedder(config.audio_config, config.text_config)
self.embed_vision = Gemma3nMultimodalEmbedder(config.vision_config,
config.text_config)
self.embed_audio = Gemma3nMultimodalEmbedder(config.audio_config,
config.text_config)
self.language_model = init_vllm_registered_model(
vllm_config=vllm_config,
@@ -553,9 +393,9 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal):
assert self.vision_tower is not None
pixel_values = image_input["pixel_values"]
vision_outputs = self.vision_tower(
pixel_values=pixel_values, do_pooling=False, return_dict=True
).last_hidden_state
vision_outputs = self.vision_tower(pixel_values=pixel_values,
do_pooling=False,
return_dict=True).last_hidden_state
vision_outputs = vision_outputs.reshape(
vision_outputs.shape[0],
self.config.vision_config.hidden_size,
@@ -566,14 +406,16 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal):
return self.embed_vision(inputs_embeds=vision_outputs)
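The reshape above flattens the vision tower's spatial grid into a sequence of soft tokens for the embedder. A torch-only sketch of the intended transform; the trailing permute is cut off by the hunk, so its exact form is an assumption:

```python
import torch

batch, hidden, h, w = 2, 2048, 16, 16
feature_map = torch.randn(batch, hidden, h, w)   # SigLIP-style tower output

# Flatten the h*w grid, then move `hidden` last so each spatial position
# becomes one soft token.
soft_tokens = feature_map.reshape(batch, hidden, h * w).permute(0, 2, 1)
assert soft_tokens.shape == (batch, 256, hidden)  # 16*16 = 256 = TOKENS_PER_IMAGE
```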
def _process_audio_input(
self, audio_input: Gemma3nAudioInputs,
self,
audio_input: Gemma3nAudioInputs,
) -> tuple[torch.Tensor, torch.Tensor]:
assert self.audio_tower is not None
input_features = audio_input["input_features"]
input_features_mask = audio_input["input_features_mask"]
audio_outputs, audio_mask = self.audio_tower(input_features, input_features_mask)
audio_outputs, audio_mask = self.audio_tower(input_features,
input_features_mask)
return self.embed_audio(inputs_embeds=audio_outputs), audio_mask
def get_language_model(self) -> torch.nn.Module:
return self.language_model

View File

@@ -617,7 +617,8 @@ class QwenVLMultiModalProcessor(BaseMultiModalProcessor[QwenVLProcessingInfo]):
) -> Mapping[str, MultiModalFieldConfig]:
return dict(
pixel_values=MultiModalFieldConfig.batched("image"),
image_embeds=MultiModalFieldConfig.batched("image"),
input_features=MultiModalFieldConfig.batched("audio"),
input_features_mask=MultiModalFieldConfig.batched("audio"),
)
def _get_prompt_updates(