mirror of https://github.com/vllm-project/vllm.git
parent
bfd63b1b10
commit
b801bf30d7
|
@ -1,59 +1,58 @@
|
||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
import math
|
|
||||||
from collections.abc import Iterable, Mapping, Sequence
|
from collections.abc import Iterable, Mapping, Sequence
|
||||||
from typing import Any, Literal, Optional, TypedDict
|
from typing import Any, Optional, TypedDict, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers import BatchFeature, Gemma3nConfig, Gemma3nProcessor
|
from transformers import AutoModel, BatchFeature
|
||||||
from transformers.models.gemma3n.processing_gemma3n import Gemma3nProcessorKwargs
|
from transformers.models.gemma3n import (Gemma3nAudioConfig,
|
||||||
from transformers import AutoModel
|
Gemma3nAudioFeatureExtractor,
|
||||||
|
Gemma3nConfig, Gemma3nProcessor,
|
||||||
|
Gemma3nTextConfig,
|
||||||
|
Gemma3nVisionConfig)
|
||||||
|
from transformers.models.siglip import SiglipImageProcessorFast
|
||||||
|
|
||||||
import vllm.envs as envs
|
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
|
from vllm.model_executor.layers.linear import RowParallelLinear
|
||||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||
VocabParallelEmbedding)
|
VocabParallelEmbedding)
|
||||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
|
||||||
MergedColumnParallelLinear,
|
|
||||||
QKVParallelLinear,
|
|
||||||
ReplicatedLinear,
|
|
||||||
RowParallelLinear)
|
|
||||||
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
from vllm.model_executor.models.module_mapping import MultiModelKeys
|
||||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||||
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
|
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalFieldConfig,
|
||||||
MultiModalKwargs)
|
MultiModalKwargs)
|
||||||
from vllm.multimodal.parse import (ImageProcessorItems, ImageSize,
|
from vllm.multimodal.parse import ImageProcessorItems, MultiModalDataItems
|
||||||
MultiModalDataItems)
|
|
||||||
# yapf: disable
|
# yapf: disable
|
||||||
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
from vllm.multimodal.processing import (BaseMultiModalProcessor,
|
||||||
BaseProcessingInfo, BoundPromptUpdate,
|
BaseProcessingInfo, BoundPromptUpdate,
|
||||||
PlaceholderFeaturesInfo,
|
PlaceholderFeaturesInfo,
|
||||||
PromptReplacement, PromptTargetMatch,
|
PromptReplacement, PromptTargetMatch,
|
||||||
PromptUpdate, PromptUpdateDetails,
|
PromptUpdate, find_mm_placeholders,
|
||||||
find_mm_placeholders,
|
|
||||||
replace_token_matches)
|
replace_token_matches)
|
||||||
# yapf: enable
|
# yapf: enable
|
||||||
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
from vllm.multimodal.profiling import BaseDummyInputsBuilder
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
|
|
||||||
from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
|
from .interfaces import MultiModalEmbeddings, SupportsMultiModal
|
||||||
SupportsMultiModal, SupportsPP)
|
from .utils import (AutoWeightsLoader, WeightsMapper,
|
||||||
from .siglip import SiglipVisionModel
|
|
||||||
from .utils import (AutoWeightsLoader, WeightsMapper, flatten_bn,
|
|
||||||
init_vllm_registered_model, maybe_prefix,
|
init_vllm_registered_model, maybe_prefix,
|
||||||
merge_multimodal_embeddings)
|
merge_multimodal_embeddings)
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
# This should be based on model config but we hardcode them for now.
|
||||||
|
TOKENS_PER_IMAGE = 256
|
||||||
|
TOKENS_PER_AUDIO = 188
|
||||||
|
|
||||||
|
|
||||||
class Gemma3nImagePixelInputs(TypedDict):
|
class Gemma3nImagePixelInputs(TypedDict):
|
||||||
pixel_values: torch.Tensor
|
pixel_values: torch.Tensor
|
||||||
"""Shape: `(batch_size * num_images, num_channels, height, width)`"""
|
"""Shape: `(batch_size * num_images, num_channels, height, width)`"""
|
||||||
|
|
||||||
|
|
||||||
class Gemma3nAudioInputs(TypedDict):
|
class Gemma3nAudioInputs(TypedDict):
|
||||||
input_features: torch.Tensor
|
input_features: torch.Tensor
|
||||||
"""Shape: `(batch_size * num_audio, seq_length, num_features)`"""
|
"""Shape: `(batch_size * num_audio, seq_length, num_features)`"""
|
||||||
|
@ -64,7 +63,7 @@ class Gemma3nAudioInputs(TypedDict):
|
||||||
Gemma3nImageInputs = Gemma3nImagePixelInputs
|
Gemma3nImageInputs = Gemma3nImagePixelInputs
|
||||||
|
|
||||||
|
|
||||||
class Gemma3ProcessingInfo(BaseProcessingInfo):
|
class Gemma3nProcessingInfo(BaseProcessingInfo):
|
||||||
|
|
||||||
def get_hf_config(self):
|
def get_hf_config(self):
|
||||||
return self.ctx.get_hf_config(Gemma3nConfig)
|
return self.ctx.get_hf_config(Gemma3nConfig)
|
||||||
|
@ -73,171 +72,26 @@ class Gemma3ProcessingInfo(BaseProcessingInfo):
|
||||||
return self.ctx.get_hf_processor(Gemma3nProcessor, **kwargs)
|
return self.ctx.get_hf_processor(Gemma3nProcessor, **kwargs)
|
||||||
|
|
||||||
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
|
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
|
||||||
return {"image": None}
|
return {"image": None, "audio": None}
|
||||||
|
|
||||||
def _resolve_image_kwargs(
|
def get_max_tokens_per_item(
|
||||||
self,
|
self, seq_len: int,
|
||||||
processor: Gemma3Processor,
|
mm_counts: Mapping[str, int]) -> Optional[Mapping[str, int]]:
|
||||||
keys: set[str],
|
|
||||||
) -> dict[str, Any]:
|
|
||||||
image_processor = processor.image_processor
|
|
||||||
kwargs = processor._merge_kwargs(
|
|
||||||
Gemma3ProcessorKwargs,
|
|
||||||
tokenizer_init_kwargs=processor.tokenizer.init_kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
images_kwargs = kwargs["images_kwargs"]
|
return {"image": TOKENS_PER_IMAGE, "audio": TOKENS_PER_AUDIO}
|
||||||
|
|
||||||
def _resolve_kw(key: str):
|
|
||||||
val = getattr(image_processor, key)
|
|
||||||
if val is None:
|
|
||||||
val = images_kwargs[key]
|
|
||||||
|
|
||||||
return val
|
|
||||||
|
|
||||||
return {k: _resolve_kw(k) for k in keys}
|
|
||||||
|
|
||||||
def get_num_crops(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
image_width: int,
|
|
||||||
image_height: int,
|
|
||||||
processor: Optional[Gemma3Processor],
|
|
||||||
) -> int:
|
|
||||||
if processor is None:
|
|
||||||
processor = self.get_hf_processor()
|
|
||||||
|
|
||||||
images_kwargs = self._resolve_image_kwargs(
|
|
||||||
processor, {
|
|
||||||
"do_pan_and_scan", "pan_and_scan_min_crop_size",
|
|
||||||
"pan_and_scan_max_num_crops",
|
|
||||||
"pan_and_scan_min_ratio_to_activate"
|
|
||||||
})
|
|
||||||
|
|
||||||
do_pan_and_scan = images_kwargs["do_pan_and_scan"]
|
|
||||||
pan_and_scan_min_crop_size = images_kwargs[
|
|
||||||
"pan_and_scan_min_crop_size"]
|
|
||||||
pan_and_scan_max_num_crops = images_kwargs[
|
|
||||||
"pan_and_scan_max_num_crops"]
|
|
||||||
pan_and_scan_min_ratio_to_activate = images_kwargs[
|
|
||||||
"pan_and_scan_min_ratio_to_activate"]
|
|
||||||
|
|
||||||
if not do_pan_and_scan:
|
|
||||||
return 0
|
|
||||||
|
|
||||||
if envs.VLLM_USE_V1:
|
|
||||||
logger.warning_once(
|
|
||||||
"`do_pan_and_scan=True` has suboptimal results on V1 "
|
|
||||||
"because of the simplified attention pattern being used.")
|
|
||||||
|
|
||||||
# Based on Gemma3ImageProcessor.pan_and_scan
|
|
||||||
if image_width >= image_height:
|
|
||||||
if image_width / image_height < pan_and_scan_min_ratio_to_activate:
|
|
||||||
return 0
|
|
||||||
|
|
||||||
num_crops_w = min(
|
|
||||||
int(math.floor(image_width / pan_and_scan_min_crop_size)),
|
|
||||||
int(math.floor(image_width / image_height + 0.5)),
|
|
||||||
)
|
|
||||||
|
|
||||||
num_crops_w = max(2, num_crops_w)
|
|
||||||
num_crops_w = min(pan_and_scan_max_num_crops, num_crops_w)
|
|
||||||
num_crops_h = 1
|
|
||||||
else:
|
|
||||||
if image_height / image_width < pan_and_scan_min_ratio_to_activate:
|
|
||||||
return 0
|
|
||||||
|
|
||||||
num_crops_h = min(
|
|
||||||
int(math.floor(image_height / pan_and_scan_min_crop_size)),
|
|
||||||
int(math.floor(image_height / image_width + 0.5)),
|
|
||||||
)
|
|
||||||
|
|
||||||
num_crops_h = max(2, num_crops_h)
|
|
||||||
num_crops_h = min(pan_and_scan_max_num_crops, num_crops_h)
|
|
||||||
num_crops_w = 1
|
|
||||||
|
|
||||||
crop_size_w = int(math.ceil(image_width / num_crops_w))
|
|
||||||
crop_size_h = int(math.ceil(image_height / num_crops_h))
|
|
||||||
|
|
||||||
if min(crop_size_w, crop_size_h) < pan_and_scan_min_crop_size:
|
|
||||||
return 0
|
|
||||||
|
|
||||||
return num_crops_w * num_crops_h
|
|
||||||
|
|
||||||
def get_image_repl(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
image_width: int,
|
|
||||||
image_height: int,
|
|
||||||
processor: Optional[Gemma3Processor],
|
|
||||||
) -> PromptUpdateDetails[str]:
|
|
||||||
if processor is None:
|
|
||||||
processor = self.get_hf_processor()
|
|
||||||
|
|
||||||
boi_token = processor.boi_token
|
|
||||||
|
|
||||||
num_crops = self.get_num_crops(
|
|
||||||
image_width=image_width,
|
|
||||||
image_height=image_height,
|
|
||||||
processor=processor,
|
|
||||||
)
|
|
||||||
|
|
||||||
if num_crops == 0:
|
|
||||||
image_text = boi_token
|
|
||||||
else:
|
|
||||||
crops_image_tokens = " ".join(boi_token for _ in range(num_crops))
|
|
||||||
image_text = (
|
|
||||||
f"Here is the original image {boi_token} and here are some "
|
|
||||||
f"crops to help you see better {crops_image_tokens}")
|
|
||||||
|
|
||||||
repl_full = image_text.replace(boi_token,
|
|
||||||
processor.full_image_sequence)
|
|
||||||
|
|
||||||
tokenizer = processor.tokenizer
|
|
||||||
vocab = tokenizer.get_vocab()
|
|
||||||
image_token_id = vocab[tokenizer.image_token]
|
|
||||||
|
|
||||||
return PromptUpdateDetails.select_token_id(repl_full, image_token_id)
|
|
||||||
|
|
||||||
def get_num_image_tokens(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
image_width: int,
|
|
||||||
image_height: int,
|
|
||||||
processor: Optional[Gemma3Processor],
|
|
||||||
) -> int:
|
|
||||||
if processor is None:
|
|
||||||
processor = self.get_hf_processor()
|
|
||||||
|
|
||||||
num_crops = self.get_num_crops(
|
|
||||||
image_width=image_width,
|
|
||||||
image_height=image_height,
|
|
||||||
processor=processor,
|
|
||||||
)
|
|
||||||
image_seq_len = processor.image_seq_length
|
|
||||||
|
|
||||||
return (num_crops + 1) * image_seq_len
|
|
||||||
|
|
||||||
def get_image_size_with_most_features(self) -> ImageSize:
|
|
||||||
processor = self.get_hf_processor()
|
|
||||||
|
|
||||||
images_kwargs = self._resolve_image_kwargs(
|
|
||||||
processor, {"pan_and_scan_max_num_crops"})
|
|
||||||
max_num_crops = images_kwargs["pan_and_scan_max_num_crops"]
|
|
||||||
|
|
||||||
# Result in the max possible feature size (h:w = max_num_crops:1)
|
|
||||||
return ImageSize(height=50 * max_num_crops, width=50)
|
|
||||||
|
|
||||||
|
|
||||||
class Gemma3DummyInputsBuilder(BaseDummyInputsBuilder[Gemma3ProcessingInfo]):
|
class Gemma3nDummyInputsBuilder(BaseDummyInputsBuilder[Gemma3nProcessingInfo]):
|
||||||
|
|
||||||
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
|
def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
|
||||||
num_images = mm_counts.get("image", 0)
|
num_images = mm_counts.get("image", 0)
|
||||||
|
num_audios = mm_counts.get("audio", 0)
|
||||||
|
|
||||||
processor = self.info.get_hf_processor()
|
processor = self.info.get_hf_processor()
|
||||||
image_token = processor.boi_token
|
image_token = processor.image_token
|
||||||
|
audio_token = processor.audio_token
|
||||||
|
|
||||||
return image_token * num_images
|
return image_token * num_images + audio_token * num_audios
|
||||||
|
|
||||||
def get_dummy_mm_data(
|
def get_dummy_mm_data(
|
||||||
self,
|
self,
|
||||||
|
@ -245,19 +99,26 @@ class Gemma3DummyInputsBuilder(BaseDummyInputsBuilder[Gemma3ProcessingInfo]):
|
||||||
mm_counts: Mapping[str, int],
|
mm_counts: Mapping[str, int],
|
||||||
) -> MultiModalDataDict:
|
) -> MultiModalDataDict:
|
||||||
num_images = mm_counts.get("image", 0)
|
num_images = mm_counts.get("image", 0)
|
||||||
|
num_audios = mm_counts.get("audio", 0)
|
||||||
target_width, target_height = \
|
processor = self.info.get_hf_processor()
|
||||||
self.info.get_image_size_with_most_features()
|
feature_extractor: Gemma3nAudioFeatureExtractor = processor.feature_extractor # noqa: E501
|
||||||
|
audio_len = feature_extractor.max_length
|
||||||
|
image_processor: SiglipImageProcessorFast = processor.image_processor
|
||||||
|
img_width = image_processor.size.get("width", 224)
|
||||||
|
img_height = image_processor.size.get("width", 224)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"image":
|
"image":
|
||||||
self._get_dummy_images(width=target_width,
|
self._get_dummy_images(width=img_width,
|
||||||
height=target_height,
|
height=img_height,
|
||||||
num_images=num_images)
|
num_images=num_images),
|
||||||
|
"audio":
|
||||||
|
self._get_dummy_audios(length=audio_len, num_audios=num_audios)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
|
class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3nProcessingInfo]
|
||||||
|
):
|
||||||
|
|
||||||
def _call_hf_processor(
|
def _call_hf_processor(
|
||||||
self,
|
self,
|
||||||
|
@ -270,27 +131,6 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
|
||||||
mm_data,
|
mm_data,
|
||||||
mm_kwargs,
|
mm_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
# HF processor pops the `num_crops` kwarg, which is needed by vLLM
|
|
||||||
if (images := mm_data.get("images")) is not None:
|
|
||||||
parsed_images = (self._get_data_parser().parse_mm_data({
|
|
||||||
"image":
|
|
||||||
images
|
|
||||||
}).get_items("image", ImageProcessorItems))
|
|
||||||
image_sizes = [
|
|
||||||
parsed_images.get_image_size(i)
|
|
||||||
for i in range(len(parsed_images))
|
|
||||||
]
|
|
||||||
hf_processor = self.info.get_hf_processor(**mm_kwargs)
|
|
||||||
|
|
||||||
num_crops = [
|
|
||||||
self.info.get_num_crops(image_width=size.width,
|
|
||||||
image_height=size.height,
|
|
||||||
processor=hf_processor)
|
|
||||||
for size in image_sizes
|
|
||||||
]
|
|
||||||
processed_outputs["num_crops"] = torch.tensor(num_crops)
|
|
||||||
|
|
||||||
return processed_outputs
|
return processed_outputs
|
||||||
|
|
||||||
def _get_mm_fields_config(
|
def _get_mm_fields_config(
|
||||||
|
@ -298,12 +138,11 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
|
||||||
hf_inputs: BatchFeature,
|
hf_inputs: BatchFeature,
|
||||||
hf_processor_mm_kwargs: Mapping[str, object],
|
hf_processor_mm_kwargs: Mapping[str, object],
|
||||||
) -> Mapping[str, MultiModalFieldConfig]:
|
) -> Mapping[str, MultiModalFieldConfig]:
|
||||||
num_crops = hf_inputs.get("num_crops", torch.empty(0))
|
|
||||||
|
|
||||||
return dict(
|
return dict(
|
||||||
pixel_values=MultiModalFieldConfig.flat_from_sizes(
|
pixel_values=MultiModalFieldConfig.batched("image"),
|
||||||
"image", num_crops + 1),
|
input_features=MultiModalFieldConfig.batched("audio"),
|
||||||
num_crops=MultiModalFieldConfig.batched("image"),
|
input_features_mask=MultiModalFieldConfig.batched("audio"),
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_prompt_updates(
|
def _get_prompt_updates(
|
||||||
|
@ -421,7 +260,8 @@ class Gemma3MultiModalProcessor(BaseMultiModalProcessor[Gemma3ProcessingInfo]):
|
||||||
|
|
||||||
|
|
||||||
class Gemma3nMultimodalEmbedder(nn.Module):
|
class Gemma3nMultimodalEmbedder(nn.Module):
|
||||||
"""Embeds token ids or soft tokens for multimodal content into language model space."""
|
"""Embeds token ids or soft tokens for multimodal content into language
|
||||||
|
model space."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -436,7 +276,6 @@ class Gemma3nMultimodalEmbedder(nn.Module):
|
||||||
self.vocab_size = multimodal_config.vocab_size
|
self.vocab_size = multimodal_config.vocab_size
|
||||||
self.text_hidden_size = text_config.hidden_size
|
self.text_hidden_size = text_config.hidden_size
|
||||||
|
|
||||||
|
|
||||||
self.embedding = VocabParallelEmbedding(
|
self.embedding = VocabParallelEmbedding(
|
||||||
self.vocab_size,
|
self.vocab_size,
|
||||||
self.multimodal_hidden_size,
|
self.multimodal_hidden_size,
|
||||||
|
@ -478,11 +317,10 @@ class Gemma3nMultimodalEmbedder(nn.Module):
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A torch.Tensor of embeddings with shape `[batch_size, seq_len, self.config.text_config.hidden_size]`.
|
A torch.Tensor of embeddings with shape `[batch_size, seq_len, self.config.text_config.hidden_size]`.
|
||||||
"""
|
""" # noqa: E501
|
||||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"You must specify exactly one of input_ids or inputs_embeds"
|
"You must specify exactly one of input_ids or inputs_embeds")
|
||||||
)
|
|
||||||
|
|
||||||
if inputs_embeds is not None:
|
if inputs_embeds is not None:
|
||||||
emb_norm = self.soft_embedding_norm(inputs_embeds)
|
emb_norm = self.soft_embedding_norm(inputs_embeds)
|
||||||
|
@ -495,8 +333,8 @@ class Gemma3nMultimodalEmbedder(nn.Module):
|
||||||
|
|
||||||
|
|
||||||
@MULTIMODAL_REGISTRY.register_processor(Gemma3MultiModalProcessor,
|
@MULTIMODAL_REGISTRY.register_processor(Gemma3MultiModalProcessor,
|
||||||
info=Gemma3ProcessingInfo,
|
info=Gemma3nProcessingInfo,
|
||||||
dummy_inputs=Gemma3DummyInputsBuilder)
|
dummy_inputs=Gemma3nDummyInputsBuilder)
|
||||||
class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal):
|
class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal):
|
||||||
packed_modules_mapping = {
|
packed_modules_mapping = {
|
||||||
"qkv_proj": [
|
"qkv_proj": [
|
||||||
|
@ -532,8 +370,10 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal):
|
||||||
|
|
||||||
self.vision_tower = AutoModel.from_config(config=config.vision_config)
|
self.vision_tower = AutoModel.from_config(config=config.vision_config)
|
||||||
self.audio_tower = AutoModel.from_config(config=config.audio_config)
|
self.audio_tower = AutoModel.from_config(config=config.audio_config)
|
||||||
self.embed_vision = Gemma3nMultimodalEmbedder(config.vision_config, config.text_config)
|
self.embed_vision = Gemma3nMultimodalEmbedder(config.vision_config,
|
||||||
self.embed_audio = Gemma3nMultimodalEmbedder(config.audio_config, config.text_config)
|
config.text_config)
|
||||||
|
self.embed_audio = Gemma3nMultimodalEmbedder(config.audio_config,
|
||||||
|
config.text_config)
|
||||||
|
|
||||||
self.language_model = init_vllm_registered_model(
|
self.language_model = init_vllm_registered_model(
|
||||||
vllm_config=vllm_config,
|
vllm_config=vllm_config,
|
||||||
|
@ -553,9 +393,9 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal):
|
||||||
assert self.vision_tower is not None
|
assert self.vision_tower is not None
|
||||||
|
|
||||||
pixel_values = image_input["pixel_values"]
|
pixel_values = image_input["pixel_values"]
|
||||||
vision_outputs = self.vision_tower(
|
vision_outputs = self.vision_tower(pixel_values=pixel_values,
|
||||||
pixel_values=pixel_values, do_pooling=False, return_dict=True
|
do_pooling=False,
|
||||||
).last_hidden_state
|
return_dict=True).last_hidden_state
|
||||||
vision_outputs = vision_outputs.reshape(
|
vision_outputs = vision_outputs.reshape(
|
||||||
vision_outputs.shape[0],
|
vision_outputs.shape[0],
|
||||||
self.config.vision_config.hidden_size,
|
self.config.vision_config.hidden_size,
|
||||||
|
@ -566,12 +406,14 @@ class Gemma3nForConditionalGeneration(nn.Module, SupportsMultiModal):
|
||||||
return self.embed_vision(inputs_embeds=vision_outputs)
|
return self.embed_vision(inputs_embeds=vision_outputs)
|
||||||
|
|
||||||
def _process_audio_input(
|
def _process_audio_input(
|
||||||
self, audio_input: Gemma3nAudioInputs,
|
self,
|
||||||
|
audio_input: Gemma3nAudioInputs,
|
||||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
assert self.audio_tower is not None
|
assert self.audio_tower is not None
|
||||||
input_features = audio_input["input_features"]
|
input_features = audio_input["input_features"]
|
||||||
input_features_mask = audio_input["input_features_mask"]
|
input_features_mask = audio_input["input_features_mask"]
|
||||||
audio_outputs, audio_mask = self.audio_tower(input_features, input_features_mask)
|
audio_outputs, audio_mask = self.audio_tower(input_features,
|
||||||
|
input_features_mask)
|
||||||
return self.embed_audio(inputs_embeds=audio_outputs), audio_mask
|
return self.embed_audio(inputs_embeds=audio_outputs), audio_mask
|
||||||
|
|
||||||
def get_language_model(self) -> torch.nn.Module:
|
def get_language_model(self) -> torch.nn.Module:
|
||||||
|
|
|
@ -617,7 +617,8 @@ class QwenVLMultiModalProcessor(BaseMultiModalProcessor[QwenVLProcessingInfo]):
|
||||||
) -> Mapping[str, MultiModalFieldConfig]:
|
) -> Mapping[str, MultiModalFieldConfig]:
|
||||||
return dict(
|
return dict(
|
||||||
pixel_values=MultiModalFieldConfig.batched("image"),
|
pixel_values=MultiModalFieldConfig.batched("image"),
|
||||||
image_embeds=MultiModalFieldConfig.batched("image"),
|
input_features=MultiModalFieldConfig.batched("audio"),
|
||||||
|
input_features_mask=MultiModalFieldConfig.batched("audio"),
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_prompt_updates(
|
def _get_prompt_updates(
|
||||||
|
|
Loading…
Reference in New Issue