"""PyTorch Qwen2Audio model.""" |
|
|
|

import math
from dataclasses import dataclass
from functools import lru_cache
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss, LayerNorm

from transformers import (
    AutoConfig,
    AutoModel,
    AutoModelForCausalLM,
    Qwen2ForCausalLM,
    Qwen2Model,
    Qwen2PreTrainedModel,
    SeamlessM4Tv2Model,
)
from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, EncoderDecoderCache, StaticCache
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import BaseModelOutput, CausalLMOutputWithPast, ModelOutput
from transformers.modeling_utils import PreTrainedModel
from transformers.models.seamless_m4t_v2.modeling_seamless_m4t_v2 import SeamlessM4Tv2SpeechEncoder
from transformers.utils import (
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    is_flash_attn_2_available,
    is_flash_attn_greater_or_equal_2_10,
    logging,
    replace_return_docstrings,
)

from .configuration_qwen2_mm import Qwen2MMConfig

if is_flash_attn_2_available():
    from transformers.modeling_flash_attention_utils import _flash_attention_forward


logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "Qwen2MMConfig"


class Qwen2AudioMultiModalProjector(nn.Module):
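    """Projects audio-encoder hidden states into the language model's embedding space."""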
|
    def __init__(self, config: Qwen2MMConfig):
        super().__init__()
        self.linear = nn.Linear(config.audio_config.hidden_size, config.hidden_size, bias=True)

    def forward(self, audio_features):
        hidden_states = self.linear(audio_features)
        return hidden_states


class Qwen2MMPreTrainedModel(PreTrainedModel):
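    """Base class binding `Qwen2MMConfig` and declaring gradient-checkpointing and attention-backend support."""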
|
    config_class = Qwen2MMConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _supports_flash_attn_2 = True
    _supports_sdpa = True


class Qwen2MMForConditionalGeneration(Qwen2MMPreTrainedModel, GenerationMixin):
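    """Qwen2 multimodal model for conditional generation.

    Audio inputs are encoded by a SeamlessM4Tv2 speech encoder (`audio_tower`), projected to the
    language-model hidden size by `audio_projector`, and merged into the token embeddings at the
    positions of the audio placeholder token before being passed to the Qwen2 decoder
    (`language_model`) and the LM head.
    """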

    def __init__(self, config):
        super().__init__(config)

        # Speech encoder that turns raw audio features into frame-level hidden states.
        self.audio_tower = SeamlessM4Tv2SpeechEncoder(config.audio_config)
        # Projector from the audio encoder hidden size to the language model hidden size.
        self.audio_projector = Qwen2AudioMultiModalProjector(config)

        self.vocab_size = config.vocab_size

        self.language_model = Qwen2Model(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        self.padding_side = "left"

        self.post_init()

    def get_input_embeddings(self):
        return self.language_model.get_input_embeddings()

    def set_input_embeddings(self, value):
        self.language_model.set_input_embeddings(value)

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.language_model = decoder

    def get_decoder(self):
        return self.language_model
|
    def _update_model_kwargs_for_generation(
        self,
        outputs: ModelOutput,
        model_kwargs: Dict[str, Any],
        is_encoder_decoder: bool = False,
        num_new_tokens: int = 1,
    ) -> Dict[str, Any]:
        model_kwargs = super()._update_model_kwargs_for_generation(
            outputs=outputs,
            model_kwargs=model_kwargs,
            is_encoder_decoder=is_encoder_decoder,
            num_new_tokens=num_new_tokens,
        )
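
        # Carry `rope_deltas` from the model outputs into the kwargs for the next generation step.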
        if getattr(outputs, "rope_deltas", None) is not None:
            model_kwargs["rope_deltas"] = outputs.rope_deltas

        return model_kwargs

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        pixel_values: Optional[torch.Tensor] = None,
        pixel_values_videos: Optional[torch.FloatTensor] = None,
        audio_values: Optional[torch.Tensor] = None,
        image_grid_thw: Optional[torch.LongTensor] = None,
        video_grid_thw: Optional[torch.LongTensor] = None,
        audio_grid_thw: Optional[torch.LongTensor] = None,
        audio_attention_mask: Optional[torch.LongTensor] = None,
        rope_deltas: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
r""" |
|
Args: |
|
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): |
|
Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., |
|
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored |
|
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. |
|
|
|
Returns: |
|
|
|
Example: |
|
|
|
```python |
|
>>> from PIL import Image |
|
>>> import requests |
|
>>> from transformers import AutoProcessor, Qwen2VLForConditionalGeneration |
|
|
|
>>> model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-7B-Instruct") |
|
>>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct") |
|
|
|
>>> messages = [ |
|
{ |
|
"role": "user", |
|
"content": [ |
|
{"type": "image"}, |
|
{"type": "text", "text": "What is shown in this image?"}, |
|
], |
|
}, |
|
] |
|
>>> url = "https://www.ilankelman.org/stopsigns/australia.jpg" |
|
>>> image = Image.open(requests.get(url, stream=True).raw) |
|
|
|
>>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) |
|
>>> inputs = processor(text=[text], images=[image], vision_infos=[vision_infos]) |
|
|
|
>>> # Generate |
|
>>> generate_ids = model.generate(inputs.input_ids, max_length=30) |
|
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] |
|
"The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..." |
|
```""" |
|
|
|

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if inputs_embeds is None:
            inputs_embeds = self.language_model.embed_tokens(input_ids)
|
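            # When audio is provided, encode it and splice the resulting embeddings into the text sequence.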
            if audio_values is not None:
                audio_values = audio_values.type(self.audio_tower.dtype)
                audio_embeds = self.audio_tower(
                    input_features=audio_values, attention_mask=audio_attention_mask
                ).last_hidden_state
                audio_embeds = self.audio_projector(audio_embeds)

                # Drop padded frames: keep only the true number of audio tokens per sample,
                # then concatenate all samples into one flat sequence of audio embeddings.
                tmp = []
                for audio_embed, audio_token_num in zip(audio_embeds, audio_grid_thw):
                    tmp.append(audio_embed[:audio_token_num, :])
                audio_embeds = torch.cat(tmp)

                n_audio_tokens = (input_ids == self.config.audio_token_id).sum().item()
                n_audio_features = audio_embeds.shape[0]
                if n_audio_tokens != n_audio_features:
                    raise ValueError(
                        f"Audio features and audio tokens do not match: tokens: {n_audio_tokens}, features {n_audio_features}"
                    )

                # Scatter the audio embeddings into the placeholder positions marked by the audio token id.
                audio_mask = (
                    (input_ids == self.config.audio_token_id)
                    .unsqueeze(-1)
                    .expand_as(inputs_embeds)
                    .to(inputs_embeds.device)
                )
                audio_embeds = audio_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
                inputs_embeds = inputs_embeds.masked_scatter(audio_mask, audio_embeds)

            if attention_mask is not None:
                attention_mask = attention_mask.to(inputs_embeds.device)

        outputs = self.language_model(
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Upcast to float for a numerically stable loss computation.
            logits = logits.float()
            # Shift so that tokens < n predict token n.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens and compute the cross-entropy loss.
            loss_fct = CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        cache_position=None,
        position_ids=None,
        use_cache=True,
        pixel_values=None,
        pixel_values_videos=None,
        audio_values=None,
        image_grid_thw=None,
        video_grid_thw=None,
        audio_grid_thw=None,
        audio_attention_mask=None,
        **kwargs,
    ):
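        """Assemble per-step generation inputs: trim cached tokens and drop multimodal inputs after prefill."""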
        if past_key_values is not None:
            if inputs_embeds is not None:
                input_ids = input_ids[:, -cache_position.shape[0] :]
            elif input_ids.shape[1] != cache_position.shape[0]:
                input_ids = input_ids[:, cache_position]

        # Multimodal inputs are only needed on the prefill step; afterwards the merged
        # embeddings are already reflected in the KV cache.
        if cache_position[0] != 0:
            pixel_values = None
            pixel_values_videos = None
            audio_values = None

        # If `inputs_embeds` are passed, only use them in the first generation step.
        if inputs_embeds is not None and cache_position[0] == 0:
            model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
        else:
            model_inputs = {"input_ids": input_ids, "inputs_embeds": None}

        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": use_cache,
                "attention_mask": attention_mask,
                "pixel_values": pixel_values,
                "pixel_values_videos": pixel_values_videos,
                "audio_values": audio_values,
                "image_grid_thw": image_grid_thw,
                "video_grid_thw": video_grid_thw,
                "audio_grid_thw": audio_grid_thw,
                # Forward the audio attention mask as well so the speech encoder sees padding info during prefill.
                "audio_attention_mask": audio_attention_mask,
            }
        )

        return model_inputs