[LMM] Implement merged multimodal processor for whisper (#13278)

Isotr0py 2025-02-23 17:46:03 +08:00 committed by GitHub
parent d5ca2110f1
commit ba5106e519
4 changed files with 144 additions and 77 deletions

View File

@@ -83,11 +83,11 @@ def _test_processing_correctness(
}
tokenizer_encode_kwargs = {}
if model_config.hf_config.model_type == "mllama":
# For Mllama, tokenizer will always add bos_token at the beginning of
# prompt by default, causing hf_processor outputs incorrect token ids.
# So we need use `add_special_tokens=False` here to leave bos_token
# to be added by the processor.
if model_config.hf_config.model_type in ("mllama", "whisper"):
# For some encoder-decoder models, the tokenizer always adds bos_token
# at the beginning of the prompt by default, causing hf_processor to
# output incorrect token ids. So we need to use `add_special_tokens=False`
# here to leave bos_token to be added by the processor.
tokenizer_encode_kwargs = {"add_special_tokens": False}
for batch_idx in range(num_batches):
@@ -173,6 +173,7 @@ def _test_processing_correctness(
"Qwen/Qwen2.5-VL-3B-Instruct",
"Qwen/Qwen2-Audio-7B-Instruct",
"fixie-ai/ultravox-v0_5-llama-3_2-1b",
"openai/whisper-large-v3",
])
@pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0])
@pytest.mark.parametrize("num_batches", [32])

View File

@@ -4,15 +4,15 @@ import math
from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict,
Union)
import numpy as np
import torch
from torch import nn
from transformers import (BatchFeature, WhisperConfig, WhisperFeatureExtractor,
WhisperProcessor)
from transformers.models.whisper.modeling_whisper import sinusoids
from vllm.attention import Attention, AttentionMetadata, AttentionType
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.inputs import INPUT_REGISTRY, DummyData, InputContext
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -25,11 +25,14 @@ from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs,
NestedTensors)
from vllm.multimodal.audio import resample_audio
from vllm.sequence import SequenceData
from vllm.transformers_utils.processor import cached_processor_from_config
from vllm.multimodal import MULTIMODAL_REGISTRY, NestedTensors
from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs
from vllm.multimodal.parse import (MultiModalDataDict, MultiModalDataItems,
MultiModalDataParser)
from vllm.multimodal.processing import (BaseProcessingInfo,
EncDecMultiModalProcessor,
PromptReplacement)
from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
from .interfaces import SupportsMultiModal, SupportsTranscription
from .utils import AutoWeightsLoader, WeightsMapper, make_layers
@@ -571,72 +574,126 @@ class WhisperModel(nn.Module):
return loaded_params
def get_max_whisper_audio_tokens(ctx: InputContext) -> int:
return ctx.model_config.hf_config.max_source_positions
class WhisperProcessingInfo(BaseProcessingInfo):
def get_hf_config(self) -> WhisperConfig:
return self.ctx.get_hf_config(WhisperConfig)
def get_hf_processor(self,
sampling_rate: Optional[int] = None
) -> WhisperProcessor:
return self.ctx.get_hf_processor(WhisperProcessor)
def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
return {"audio": 1}
def get_feature_extractor(self) -> WhisperFeatureExtractor:
hf_processor = self.get_hf_processor()
feature_extractor = hf_processor.feature_extractor # type: ignore
assert isinstance(feature_extractor, WhisperFeatureExtractor)
return feature_extractor
def get_max_audio_tokens(self) -> int:
return self.get_hf_config().max_source_positions
def get_mm_max_tokens_per_item(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> Mapping[str, int]:
return {"audio": self.get_max_audio_tokens()}
def dummy_encoder_data_for_whisper(ctx: InputContext, seq_len: int,
mm_counts: Mapping[str, int]):
assert mm_counts["audio"] == 1
num_tokens = get_max_whisper_audio_tokens(ctx)
processor = cached_processor_from_config(ctx.model_config)
chunk_length = processor.feature_extractor.chunk_length
sampling_rate = processor.feature_extractor.sampling_rate
num_samples = chunk_length * sampling_rate
return DummyData(
SequenceData.from_prompt_token_counts((0, num_tokens)),
{"audio": [(np.zeros(num_samples), sampling_rate)]},
)
class WhisperDummyInputsBuilder(BaseDummyInputsBuilder[WhisperProcessingInfo]):
def get_dummy_processor_inputs(
self,
seq_len: int,
mm_counts: Mapping[str, int],
) -> ProcessorInputs:
feature_extractor = self.info.get_feature_extractor()
sampling_rate = feature_extractor.sampling_rate
audio_len = feature_extractor.chunk_length * sampling_rate
num_audios = mm_counts.get("audio", 0)
mm_data = {
"audio":
self._get_dummy_audios(length=audio_len, num_audios=num_audios)
}
return ProcessorInputs(
prompt_text="<|startoftranscript|>" * num_audios,
mm_data=mm_data,
)
def input_processor_for_whisper(ctx: InputContext, inputs):
multi_modal_data = inputs["encoder"]["multi_modal_data"]
if isinstance(multi_modal_data["audio"], list):
assert len(multi_modal_data["audio"]) == 1
multi_modal_data["audio"] = multi_modal_data["audio"][0]
# Resample and process audio
audio, orig_sr = multi_modal_data["audio"]
processor = cached_processor_from_config(ctx.model_config)
target_sr = processor.feature_extractor.sampling_rate
audio = resample_audio(audio, orig_sr=orig_sr, target_sr=target_sr)
multi_modal_data["audio"] = (audio, target_sr)
# Pre-allocate placeholder tokens in encoder sequence
num_tokens = get_max_whisper_audio_tokens(ctx)
inputs["encoder"]["prompt_token_ids"] = [0] * num_tokens
return inputs
class WhisperMultiModalProcessor(
EncDecMultiModalProcessor[WhisperProcessingInfo]):
def _get_data_parser(self) -> MultiModalDataParser:
feature_extractor = self.info.get_feature_extractor()
return MultiModalDataParser(target_sr=feature_extractor.sampling_rate)
def create_encoder_prompt(
self,
prompt: Union[str, list[int]],
mm_data: MultiModalDataDict,
) -> Union[str, list[int]]:
# Strictly speaking, the Whisper encoder only accepts audio features.
# We create a dummy encoder prompt here, which will be padded to
# num_audio_tokens, so that we can create dummy data from it
# for encoder profiling.
return [0]
def _call_hf_processor(
self,
prompt: str,
mm_data: Mapping[str, object],
mm_kwargs: Mapping[str, object],
) -> BatchFeature:
if mm_data:
feature_extractor = self.info.get_feature_extractor(**mm_kwargs)
mm_data = dict(audio=mm_data.pop("audios"))
mm_kwargs = dict(
**mm_kwargs,
sampling_rate=feature_extractor.sampling_rate,
)
processed_outputs = super()._call_hf_processor(
prompt=prompt,
mm_data=mm_data,
mm_kwargs=mm_kwargs,
)
if "labels" in processed_outputs:
processed_outputs["input_ids"] = processed_outputs.pop("labels")
return processed_outputs
def _get_mm_fields_config(
self,
hf_inputs: BatchFeature,
hf_processor_mm_kwargs: Mapping[str, object],
) -> Mapping[str, MultiModalFieldConfig]:
return dict(input_features=MultiModalFieldConfig.batched("audio"))
def _get_prompt_replacements(
self,
mm_items: MultiModalDataItems,
hf_processor_mm_kwargs: Mapping[str, object],
out_mm_kwargs: MultiModalKwargs,
) -> list[PromptReplacement]:
num_tokens = self.info.get_max_audio_tokens()
return [
PromptReplacement(
modality="audio",
target=[0],
replacement=[0] * num_tokens,
)
]
def input_mapper_for_whisper(
ctx: InputContext,
multi_modal_data: Union[np.ndarray, List[np.ndarray]],
) -> MultiModalKwargs:
if not isinstance(multi_modal_data, list):
multi_modal_data = [multi_modal_data]
assert len(multi_modal_data) == 1
if len(multi_modal_data) == 0:
return MultiModalKwargs()
processor = cached_processor_from_config(ctx.model_config)
sampling_rate = processor.feature_extractor.sampling_rate
audios = [audio for audio, _ in multi_modal_data]
kwargs = processor(audios,
sampling_rate=sampling_rate,
return_tensors="pt")
kwargs["input_features"] = kwargs["input_features"].squeeze(0).to(
ctx.model_config.dtype)
return MultiModalKwargs(kwargs)
@INPUT_REGISTRY.register_dummy_encoder_data(dummy_encoder_data_for_whisper)
@INPUT_REGISTRY.register_input_processor(input_processor_for_whisper)
@MULTIMODAL_REGISTRY.register_input_mapper("audio", input_mapper_for_whisper)
@MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
"audio", get_max_whisper_audio_tokens)
@MULTIMODAL_REGISTRY.register_processor(WhisperMultiModalProcessor,
info=WhisperProcessingInfo,
dummy_inputs=WhisperDummyInputsBuilder)
class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
SupportsMultiModal):
packed_modules_mapping = {
@@ -724,7 +781,8 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
if not isinstance(input_features, (torch.Tensor, list)):
raise ValueError("Incorrect type of audio features. "
f"Got type: {type(input_features)}")
input_features = [feat.to(self.dtype) for feat in input_features]
input_features = torch.cat(
[feat.to(self.dtype) for feat in input_features])
return WhisperAudioInputs(input_features=input_features)
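As a quick sanity check on the numbers above (a sketch, not part of the diff, assuming the stock whisper-large-v3 config): the single dummy token returned by create_encoder_prompt is expanded by the PromptReplacement to max_source_positions placeholder tokens, the same count the old input_processor_for_whisper pre-allocated.

from transformers import WhisperConfig

config = WhisperConfig.from_pretrained("openai/whisper-large-v3")
num_tokens = config.max_source_positions  # 1500 for the large-v3 checkpoint

encoder_prompt = [0]             # what create_encoder_prompt() returns
expanded = [0] * num_tokens      # what the PromptReplacement expands it to per audio
assert len(expanded) == 1500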

View File

@@ -1297,7 +1297,10 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
prompt: Union[str, list[int]],
mm_data: MultiModalDataDict,
) -> Union[str, list[int]]:
"""Create input prompt for the encoder."""
"""
Create the input prompt for the encoder. The HF processor will be
applied to this prompt during profiling and generation.
"""
raise NotImplementedError
def apply(
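To illustrate the contract described in the new docstring, here is a hypothetical subclass sketch (the class name is made up and only the relevant override is shown; import paths follow the whisper.py diff above).

from typing import Union

from vllm.multimodal.parse import MultiModalDataDict
from vllm.multimodal.processing import EncDecMultiModalProcessor


class MyEncDecProcessor(EncDecMultiModalProcessor):

    def create_encoder_prompt(
        self,
        prompt: Union[str, list[int]],
        mm_data: MultiModalDataDict,
    ) -> Union[str, list[int]]:
        # Whatever is returned here is itself run through the HF processor
        # during both profiling and generation, so a minimal placeholder
        # (later expanded via prompt replacements) is enough.
        return [0]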

View File

@@ -166,8 +166,12 @@ class MultiModalProfiler(Generic[_I]):
f"({set(mm_max_tokens_per_item.keys())})")
mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)
prompt_token_ids = mm_inputs["prompt_token_ids"]
placeholders_by_modality = mm_inputs["mm_placeholders"]
# For encoder-decoder models, use encoder prompt token ids instead of
# decoder prompt to construct dummy seq_data for encoder profiling.
prompt_token_ids = (
mm_inputs["prompt_token_ids"] if not is_encoder_data else
mm_inputs["encoder_prompt_token_ids"]) # type: ignore
total_placeholders_by_modality = {
modality: sum(item["length"] for item in placeholders)
@@ -188,7 +192,7 @@ class MultiModalProfiler(Generic[_I]):
# V0 does not support chunked prefill.
if (total_len > seq_len and not envs.VLLM_USE_V1) or is_encoder_data:
if total_len > seq_len:
if total_len > seq_len and not is_encoder_data:
logger.warning(
"The context length (%d) of the model is too short "
"to hold the multi-modal embeddings in the worst case "
@@ -201,7 +205,8 @@ class MultiModalProfiler(Generic[_I]):
total_placeholders_by_modality)
return DummyData(
seq_data=SequenceData.from_prompt_token_counts((0, seq_len)),
seq_data=SequenceData.from_prompt_token_counts(
(0, max(seq_len, total_len))),
multi_modal_data=None,
multi_modal_placeholders=None,
)
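Concretely, with Whisper-style numbers (assumed here for illustration: a 448-token decoder context and 1500 encoder placeholder tokens), the change keeps the full encoder placeholder count instead of truncating it to seq_len:

seq_len = 448      # decoder context length (e.g. Whisper's max_target_positions)
total_len = 1500   # encoder placeholder tokens (max_source_positions)

# The old behaviour capped the dummy sequence at seq_len; the new max() keeps
# every encoder placeholder for profiling.
dummy_len = max(seq_len, total_len)
assert dummy_len == 1500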