mirror of https://github.com/vllm-project/vllm.git
[LMM] Implement merged multimodal processor for whisper (#13278)
This commit is contained in:
parent d5ca2110f1
commit ba5106e519
@@ -83,11 +83,11 @@ def _test_processing_correctness(
     }
 
     tokenizer_encode_kwargs = {}
-    if model_config.hf_config.model_type == "mllama":
-        # For Mllama, tokenizer will always add bos_token at the beginning of
-        # prompt by default, causing hf_processor outputs incorrect token ids.
-        # So we need use `add_special_tokens=False` here to leave bos_token
-        # to be added by the processor.
+    if model_config.hf_config.model_type in ("mllama", "whisper"):
+        # For some encoder-decoder models, the tokenizer always adds bos_token
+        # at the beginning of the prompt by default, causing hf_processor to
+        # output incorrect token ids. So we need `add_special_tokens=False`
+        # here, leaving bos_token to be added by the processor.
         tokenizer_encode_kwargs = {"add_special_tokens": False}
 
     for batch_idx in range(num_batches):
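For reference, a minimal sketch (not part of the diff) of why `add_special_tokens=False` matters when encoding the prompt directly with the tokenizer; it assumes the `openai/whisper-tiny` checkpoint can be downloaded:

# Whisper's tokenizer prepends special prefix tokens and appends <|endoftext|>
# by default, so encoding a prompt that the HF processor will also wrap would
# duplicate those special tokens.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("openai/whisper-tiny")

with_specials = tokenizer("Hello world").input_ids
without_specials = tokenizer("Hello world", add_special_tokens=False).input_ids

print(tokenizer.convert_ids_to_tokens(with_specials))
# e.g. ['<|startoftranscript|>', '<|notimestamps|>', 'Hello', 'Ġworld', '<|endoftext|>']
print(tokenizer.convert_ids_to_tokens(without_specials))
# e.g. ['Hello', 'Ġworld']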
@@ -173,6 +173,7 @@ def _test_processing_correctness(
     "Qwen/Qwen2.5-VL-3B-Instruct",
     "Qwen/Qwen2-Audio-7B-Instruct",
     "fixie-ai/ultravox-v0_5-llama-3_2-1b",
+    "openai/whisper-large-v3",
 ])
 @pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0])
 @pytest.mark.parametrize("num_batches", [32])
@@ -4,15 +4,15 @@ import math
 from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict,
                     Union)
 
 import numpy as np
 import torch
 from torch import nn
+from transformers import (BatchFeature, WhisperConfig, WhisperFeatureExtractor,
+                          WhisperProcessor)
 from transformers.models.whisper.modeling_whisper import sinusoids
 
 from vllm.attention import Attention, AttentionMetadata, AttentionType
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.inputs import INPUT_REGISTRY, DummyData, InputContext
 from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -25,11 +25,14 @@ from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs,
-                             NestedTensors)
-from vllm.multimodal.audio import resample_audio
-from vllm.sequence import SequenceData
-from vllm.transformers_utils.processor import cached_processor_from_config
+from vllm.multimodal import MULTIMODAL_REGISTRY, NestedTensors
+from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargs
+from vllm.multimodal.parse import (MultiModalDataDict, MultiModalDataItems,
+                                   MultiModalDataParser)
+from vllm.multimodal.processing import (BaseProcessingInfo,
+                                        EncDecMultiModalProcessor,
+                                        PromptReplacement)
+from vllm.multimodal.profiling import BaseDummyInputsBuilder, ProcessorInputs
 
 from .interfaces import SupportsMultiModal, SupportsTranscription
 from .utils import AutoWeightsLoader, WeightsMapper, make_layers
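The new imports pull `BatchFeature`, `WhisperFeatureExtractor` and `WhisperProcessor` in from `transformers`. As a reference point (not part of the diff), this is roughly what the feature extractor produces for the `input_features` field that the merged processor below passes through; the sketch uses the extractor's defaults instead of a real checkpoint:

import numpy as np
from transformers import WhisperFeatureExtractor

# Default settings: 80 mel bins, 16 kHz sampling rate, 30 s chunks
# (whisper-large-v3 uses 128 mel bins, so shapes differ per checkpoint).
feature_extractor = WhisperFeatureExtractor()

# One second of silence; the extractor pads it out to a full 30 s window.
audio = np.zeros(16000, dtype=np.float32)
features = feature_extractor(audio, sampling_rate=16000, return_tensors="np")

print(features["input_features"].shape)
# (1, 80, 3000): 80 mel bins x 3000 frames (30 s at a 10 ms hop)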
@@ -571,72 +574,126 @@ class WhisperModel(nn.Module):
         return loaded_params
 
 
-def get_max_whisper_audio_tokens(ctx: InputContext) -> int:
-    return ctx.model_config.hf_config.max_source_positions
+class WhisperProcessingInfo(BaseProcessingInfo):
+
+    def get_hf_config(self) -> WhisperConfig:
+        return self.ctx.get_hf_config(WhisperConfig)
+
+    def get_hf_processor(self,
+                         sampling_rate: Optional[int] = None
+                         ) -> WhisperProcessor:
+        return self.ctx.get_hf_processor(WhisperProcessor)
+
+    def get_supported_mm_limits(self) -> Mapping[str, Optional[int]]:
+        return {"audio": 1}
+
+    def get_feature_extractor(self) -> WhisperFeatureExtractor:
+        hf_processor = self.get_hf_processor()
+        feature_extractor = hf_processor.feature_extractor  # type: ignore
+        assert isinstance(feature_extractor, WhisperFeatureExtractor)
+        return feature_extractor
+
+    def get_max_audio_tokens(self) -> int:
+        return self.get_hf_config().max_source_positions
+
+    def get_mm_max_tokens_per_item(
+        self,
+        seq_len: int,
+        mm_counts: Mapping[str, int],
+    ) -> Mapping[str, int]:
+        return {"audio": self.get_max_audio_tokens()}
 
 
-def dummy_encoder_data_for_whisper(ctx: InputContext, seq_len: int,
-                                   mm_counts: Mapping[str, int]):
-    assert mm_counts["audio"] == 1
-    num_tokens = get_max_whisper_audio_tokens(ctx)
-    processor = cached_processor_from_config(ctx.model_config)
-    chunk_length = processor.feature_extractor.chunk_length
-    sampling_rate = processor.feature_extractor.sampling_rate
-    num_samples = chunk_length * sampling_rate
-    return DummyData(
-        SequenceData.from_prompt_token_counts((0, num_tokens)),
-        {"audio": [(np.zeros(num_samples), sampling_rate)]},
-    )
+class WhisperDummyInputsBuilder(BaseDummyInputsBuilder[WhisperProcessingInfo]):
+
+    def get_dummy_processor_inputs(
+        self,
+        seq_len: int,
+        mm_counts: Mapping[str, int],
+    ) -> ProcessorInputs:
+        feature_extractor = self.info.get_feature_extractor()
+
+        sampling_rate = feature_extractor.sampling_rate
+        audio_len = feature_extractor.chunk_length * sampling_rate
+        num_audios = mm_counts.get("audio", 0)
+
+        mm_data = {
+            "audio":
+            self._get_dummy_audios(length=audio_len, num_audios=num_audios)
+        }
+
+        return ProcessorInputs(
+            prompt_text="<|startoftranscript|>" * num_audios,
+            mm_data=mm_data,
+        )
 
 
-def input_processor_for_whisper(ctx: InputContext, inputs):
-    multi_modal_data = inputs["encoder"]["multi_modal_data"]
-    if isinstance(multi_modal_data["audio"], list):
-        assert len(multi_modal_data["audio"]) == 1
-        multi_modal_data["audio"] = multi_modal_data["audio"][0]
-    # Resample and process audio
-    audio, orig_sr = multi_modal_data["audio"]
-    processor = cached_processor_from_config(ctx.model_config)
-    target_sr = processor.feature_extractor.sampling_rate
-    audio = resample_audio(audio, orig_sr=orig_sr, target_sr=target_sr)
-    multi_modal_data["audio"] = (audio, target_sr)
-    # Pre-allocate placeholder tokens in encoder sequence
-    num_tokens = get_max_whisper_audio_tokens(ctx)
-    inputs["encoder"]["prompt_token_ids"] = [0] * num_tokens
-    return inputs
+class WhisperMultiModalProcessor(
+        EncDecMultiModalProcessor[WhisperProcessingInfo]):
+
+    def _get_data_parser(self) -> MultiModalDataParser:
+        feature_extractor = self.info.get_feature_extractor()
+        return MultiModalDataParser(target_sr=feature_extractor.sampling_rate)
+
+    def create_encoder_prompt(
+        self,
+        prompt: Union[str, list[int]],
+        mm_data: MultiModalDataDict,
+    ) -> Union[str, list[int]]:
+        # Strictly speaking, the whisper encoder only accepts audio features.
+        # We create a dummy encoder prompt here which will be padded to
+        # num_audio_tokens, so that we can create dummy data from it
+        # for encoder profiling.
+        return [0]
+
+    def _call_hf_processor(
+        self,
+        prompt: str,
+        mm_data: Mapping[str, object],
+        mm_kwargs: Mapping[str, object],
+    ) -> BatchFeature:
+        if mm_data:
+            feature_extractor = self.info.get_feature_extractor(**mm_kwargs)
+            mm_data = dict(audio=mm_data.pop("audios"))
+            mm_kwargs = dict(
+                **mm_kwargs,
+                sampling_rate=feature_extractor.sampling_rate,
+            )
+        processed_outputs = super()._call_hf_processor(
+            prompt=prompt,
+            mm_data=mm_data,
+            mm_kwargs=mm_kwargs,
+        )
+        if "labels" in processed_outputs:
+            processed_outputs["input_ids"] = processed_outputs.pop("labels")
+        return processed_outputs
+
+    def _get_mm_fields_config(
+        self,
+        hf_inputs: BatchFeature,
+        hf_processor_mm_kwargs: Mapping[str, object],
+    ) -> Mapping[str, MultiModalFieldConfig]:
+        return dict(input_features=MultiModalFieldConfig.batched("audio"))
+
+    def _get_prompt_replacements(
+        self,
+        mm_items: MultiModalDataItems,
+        hf_processor_mm_kwargs: Mapping[str, object],
+        out_mm_kwargs: MultiModalKwargs,
+    ) -> list[PromptReplacement]:
+        num_tokens = self.info.get_max_audio_tokens()
+        return [
+            PromptReplacement(
+                modality="audio",
+                target=[0],
+                replacement=[0] * num_tokens,
+            )
+        ]
 
 
-def input_mapper_for_whisper(
-    ctx: InputContext,
-    multi_modal_data: Union[np.ndarray, List[np.ndarray]],
-) -> MultiModalKwargs:
-    if not isinstance(multi_modal_data, list):
-        multi_modal_data = [multi_modal_data]
-
-    assert len(multi_modal_data) == 1
-
-    if len(multi_modal_data) == 0:
-        return MultiModalKwargs()
-
-    processor = cached_processor_from_config(ctx.model_config)
-    sampling_rate = processor.feature_extractor.sampling_rate
-
-    audios = [audio for audio, _ in multi_modal_data]
-
-    kwargs = processor(audios,
-                       sampling_rate=sampling_rate,
-                       return_tensors="pt")
-    kwargs["input_features"] = kwargs["input_features"].squeeze(0).to(
-        ctx.model_config.dtype)
-
-    return MultiModalKwargs(kwargs)
-
-
-@INPUT_REGISTRY.register_dummy_encoder_data(dummy_encoder_data_for_whisper)
-@INPUT_REGISTRY.register_input_processor(input_processor_for_whisper)
-@MULTIMODAL_REGISTRY.register_input_mapper("audio", input_mapper_for_whisper)
-@MULTIMODAL_REGISTRY.register_max_multimodal_tokens(
-    "audio", get_max_whisper_audio_tokens)
+@MULTIMODAL_REGISTRY.register_processor(WhisperMultiModalProcessor,
+                                        info=WhisperProcessingInfo,
+                                        dummy_inputs=WhisperDummyInputsBuilder)
 class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
                                       SupportsMultiModal):
     packed_modules_mapping = {
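`get_max_audio_tokens()` returns `max_source_positions`, and `_get_prompt_replacements` expands the single-token encoder prompt to that many placeholders. A back-of-the-envelope sketch (not from the diff, using default transformers configs rather than a specific checkpoint) of where that number comes from:

from transformers import WhisperConfig, WhisperFeatureExtractor

config = WhisperConfig()                       # max_source_positions=1500 by default
feature_extractor = WhisperFeatureExtractor()  # 30 s chunks at 16 kHz, 160-sample hop

num_samples = feature_extractor.chunk_length * feature_extractor.sampling_rate
num_frames = num_samples // feature_extractor.hop_length  # 3000 mel frames
# The encoder's second conv layer has stride 2, halving the frame count.
assert num_frames // 2 == config.max_source_positions

print(config.max_source_positions)  # 1500 placeholder tokens per audio clip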
@@ -724,7 +781,8 @@ class WhisperForConditionalGeneration(nn.Module, SupportsTranscription,
             if not isinstance(input_features, (torch.Tensor, list)):
                 raise ValueError("Incorrect type of audio features. "
                                  f"Got type: {type(input_features)}")
-            input_features = [feat.to(self.dtype) for feat in input_features]
+            input_features = torch.cat(
+                [feat.to(self.dtype) for feat in input_features])
 
         return WhisperAudioInputs(input_features=input_features)
 
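An illustrative sketch of the fix above (values made up): the parsed `input_features` arrive as a list of per-item tensors, and the new code concatenates them into one batch tensor while casting to the model dtype, instead of leaving a Python list behind:

import torch

dtype = torch.float16
# Assume two audio items, each a (1, n_mels, n_frames) feature tensor.
input_features = [torch.randn(1, 80, 3000), torch.randn(1, 80, 3000)]

batched = torch.cat([feat.to(dtype) for feat in input_features])
print(batched.shape, batched.dtype)  # torch.Size([2, 80, 3000]) torch.float16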
@@ -1297,7 +1297,10 @@ class EncDecMultiModalProcessor(BaseMultiModalProcessor[_I]):
         prompt: Union[str, list[int]],
         mm_data: MultiModalDataDict,
     ) -> Union[str, list[int]]:
-        """Create input prompt for the encoder."""
+        """
+        Create the input prompt for the encoder. The HF processor will be
+        applied to this prompt during profiling and generation.
+        """
         raise NotImplementedError
 
     def apply(
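A minimal illustration of the contract spelled out in the new docstring, modeled on the Whisper processor above; the class name here is illustrative and not taken from the diff:

from typing import Union


class AudioOnlyEncoderPromptMixin:  # illustrative name, not from the diff

    def create_encoder_prompt(
        self,
        prompt: Union[str, list[int]],
        mm_data: dict,
    ) -> Union[str, list[int]]:
        # The encoder consumes only audio features, so the textual prompt is
        # irrelevant; return a single placeholder token id, which the prompt
        # replacement machinery later expands to the full audio length.
        return [0]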
@@ -166,8 +166,12 @@ class MultiModalProfiler(Generic[_I]):
                 f"({set(mm_max_tokens_per_item.keys())})")
 
         mm_inputs = self._get_dummy_mm_inputs(seq_len, mm_counts)
-        prompt_token_ids = mm_inputs["prompt_token_ids"]
         placeholders_by_modality = mm_inputs["mm_placeholders"]
+        # For encoder-decoder models, use encoder prompt token ids instead of
+        # decoder prompt to construct dummy seq_data for encoder profiling.
+        prompt_token_ids = (
+            mm_inputs["prompt_token_ids"] if not is_encoder_data else
+            mm_inputs["encoder_prompt_token_ids"])  # type: ignore
 
         total_placeholders_by_modality = {
             modality: sum(item["length"] for item in placeholders)
@@ -188,7 +192,7 @@ class MultiModalProfiler(Generic[_I]):
 
         # V0 does not support chunked prefill.
         if (total_len > seq_len and not envs.VLLM_USE_V1) or is_encoder_data:
-            if total_len > seq_len:
+            if total_len > seq_len and not is_encoder_data:
                 logger.warning(
                     "The context length (%d) of the model is too short "
                     "to hold the multi-modal embeddings in the worst case "
@@ -201,7 +205,8 @@ class MultiModalProfiler(Generic[_I]):
                     total_placeholders_by_modality)
 
             return DummyData(
-                seq_data=SequenceData.from_prompt_token_counts((0, seq_len)),
+                seq_data=SequenceData.from_prompt_token_counts(
+                    (0, max(seq_len, total_len))),
                 multi_modal_data=None,
                 multi_modal_placeholders=None,
             )
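A worked example of the `seq_data` change, with assumed numbers: for Whisper encoder profiling the dummy encoder prompt holds 1500 placeholder tokens (`max_source_positions`), which can exceed the profiled sequence length, so padding only to `seq_len` would truncate it:

seq_len = 448      # e.g. Whisper's decoder max_target_positions
total_len = 1500   # placeholder tokens in the dummy encoder prompt

# The old code padded the dummy sequence to seq_len only; the new code pads
# to whichever of the two is larger.
print(max(seq_len, total_len))  # 1500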