import os
import sys

import mmengine
import torch
import torch.nn as nn
from mmengine.device import get_device
from transformers import StoppingCriteriaList

from opencompass.registry import MM_MODELS

from .utils import StoppingCriteriaSub


class LayerNorm(nn.LayerNorm):
    """Subclass torch's LayerNorm to handle fp16."""

    def forward(self, x: torch.Tensor):
        orig_type = x.dtype
        ret = super().forward(x.type(torch.float32))
        return ret.type(orig_type)


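# The MiniGPT-4 repository is expected to sit next to this file in a
# ``MiniGPT-4`` folder; it is added to ``sys.path`` only for the duration of
# the import below. The try/except covers the two module paths under which
# the ``MiniGPT4`` class may be found (``minigpt4.py`` vs ``mini_gpt4.py``).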
def load_package():
    """Load the MiniGPT4 model class from the local MiniGPT-4 repository."""
    current_file_path = os.path.abspath(__file__)
    current_folder_path = os.path.dirname(current_file_path)

    sys.path.append(os.path.join(current_folder_path, 'MiniGPT-4'))
    try:
        from minigpt4.models.minigpt4 import MiniGPT4
    except ImportError:
        from minigpt4.models.mini_gpt4 import MiniGPT4
    sys.path.pop(-1)

    return MiniGPT4


MiniGPT4 = load_package()


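# Illustrative config sketch (not taken from this file): the model is built
# from an OpenCompass config through the MM_MODELS registry, roughly like
#
#   model = dict(
#       type='minigpt-4',
#       llama_model='/path/to/vicuna',
#       prompt_constructor=dict(type='<SomePromptConstructor>'),
#       post_processor=dict(type='<SomePostProcessor>'),
#       mode='generation')
#
# where the prompt constructor and post processor type names are placeholders.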
@MM_MODELS.register_module('minigpt-4')
class MiniGPT4Inferencer(MiniGPT4):
    """Inference code of MiniGPT-4.

    Args:
        llama_model (str): The path of the vicuna model.
        prompt_constructor (dict): The config of the prompt constructor.
        post_processor (dict): The config of the post processor.
        do_sample (bool): Whether to use sampling. Defaults to False.
        max_length (int): The max length of the output. Defaults to 30.
        img_size (int): The size of the input image. Defaults to 224.
        low_resource (bool): Whether to load the model in low precision.
            Defaults to False.
        is_caption_task (bool): Whether the task is a caption task.
            Defaults to False.
        mode (str): Inference mode, either 'generation' or 'loss'.
            Defaults to 'generation'.
        n_segments (int): Number of segments to split the candidate
            choices into when computing the loss. Defaults to 1.
    """

    def __init__(self,
                 llama_model: str,
                 prompt_constructor: dict,
                 post_processor: dict,
                 do_sample: bool = False,
                 max_length: int = 30,
                 img_size: int = 224,
                 low_resource: bool = False,
                 is_caption_task: bool = False,
                 mode: str = 'generation',
                 n_segments: int = 1) -> None:
        super().__init__(llama_model=llama_model,
                         low_resource=low_resource,
                         img_size=img_size)
        self.mode = mode
        self.n_segments = n_segments

        # Stop generation at the '###' separator of the MiniGPT-4
        # conversation template; the two id sequences are alternative
        # tokenizations of that separator.
        cur_device = get_device()
        stop_words_ids = [
            torch.tensor([835]).to(cur_device),
            torch.tensor([2277, 29937]).to(cur_device),
        ]
        self.stopping_criteria = StoppingCriteriaList(
            [StoppingCriteriaSub(stops=stop_words_ids)])

        self.prompt_constructor = mmengine.registry.build_from_cfg(
            prompt_constructor, MM_MODELS)
        if post_processor is not None:
            self.post_processor = mmengine.registry.build_from_cfg(
                post_processor, MM_MODELS)
        self.do_sample = do_sample
        self.max_length = max_length
        self.is_caption_task = is_caption_task

    def forward(self, batch):
        if self.mode == 'generation':
            return self.generate(batch)
        elif self.mode == 'loss':
            return self.loss(batch)
        else:
            raise RuntimeError(f'Invalid mode "{self.mode}".')

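    # ``encode_img`` accepts either a single image batch of shape
    # (B, C, H, W) or, when ``image.dim() == 5``, a stack of video frames
    # assumed to be laid out as (B, C, T, H, W); each frame is encoded
    # separately and the projected query tokens of all frames are
    # concatenated along the sequence dimension.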
    def encode_img(self, image):
        device = image.device

        with self.maybe_autocast():
            if image.dim() == 5:
                # Video input: encode each frame independently, then
                # concatenate the projected query tokens of all frames.
                inputs_llama, atts_llama = [], []
                for j in range(image.size(2)):
                    this_frame = image[:, :, j, :, :]
                    frame_embeds = self.ln_vision(
                        self.visual_encoder(this_frame))
                    frame_atts = torch.ones(frame_embeds.size()[:-1],
                                            dtype=torch.long).to(image.device)

                    query_tokens = self.query_tokens.expand(
                        frame_embeds.shape[0], -1, -1)
                    frame_query_output = self.Qformer.bert(
                        query_embeds=query_tokens,
                        encoder_hidden_states=frame_embeds,
                        encoder_attention_mask=frame_atts,
                        return_dict=True,
                    )

                    frame_inputs_llama = self.llama_proj(
                        frame_query_output.last_hidden_state[
                            :, :query_tokens.size(1), :])
                    frame_atts_llama = torch.ones(
                        frame_inputs_llama.size()[:-1],
                        dtype=torch.long).to(image.device)
                    inputs_llama.append(frame_inputs_llama)
                    atts_llama.append(frame_atts_llama)
                inputs_llama = torch.cat(inputs_llama, dim=1)
                atts_llama = torch.cat(atts_llama, dim=1)
            else:
                # Single image input: ViT -> Q-Former -> projection into the
                # LLaMA embedding space.
                image_embeds = self.ln_vision(
                    self.visual_encoder(image)).to(device)
                image_atts = torch.ones(image_embeds.size()[:-1],
                                        dtype=torch.long).to(device)

                query_tokens = self.query_tokens.expand(
                    image_embeds.shape[0], -1, -1)
                query_output = self.Qformer.bert(
                    query_embeds=query_tokens,
                    encoder_hidden_states=image_embeds,
                    encoder_attention_mask=image_atts,
                    return_dict=True,
                )

                inputs_llama = self.llama_proj(query_output.last_hidden_state)
                atts_llama = torch.ones(inputs_llama.size()[:-1],
                                        dtype=torch.long).to(image.device)
        return inputs_llama, atts_llama

    def pack_inputs(self, batch):
        images = [image.unsqueeze(0) for image in batch['inputs']]
        data_samples = list(batch['data_samples'])
        images = torch.cat(images, dim=0).to(get_device())
        inputs = {'image': images, 'data_samples': data_samples}
        return inputs

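    # The prompt built by ``prompt_constructor`` is expected to contain a
    # single ``<ImageHere>`` placeholder: the text before and after it is
    # tokenized and embedded separately, and the projected image embeddings
    # are spliced in between before generation.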
    def generate(self, batch):
        inputs = self.pack_inputs(batch)
        inputs = self.prompt_constructor(inputs)
        image = inputs['image']
        prompt = inputs['prompt']
        data_samples = inputs['data_samples']

        # Embed the image and the two text segments around '<ImageHere>',
        # then concatenate them into a single input embedding sequence.
        img_embeds, _ = self.encode_img(image)
        prompt_segs = prompt.split('<ImageHere>')
        prompt_seg_tokens = [
            self.llama_tokenizer(
                seg, return_tensors='pt',
                add_special_tokens=(i == 0)).to(
                    self.llama_model.model.embed_tokens.weight.device
                ).input_ids
            for i, seg in enumerate(prompt_segs)
        ]
        prompt_seg_embs = [
            self.llama_model.model.embed_tokens(seg)
            for seg in prompt_seg_tokens
        ]
        prompt_seg_embs = [prompt_seg_embs[0], img_embeds, prompt_seg_embs[1]]
        prompt_embs = torch.cat(prompt_seg_embs, dim=1)

        outputs = self.llama_model.generate(
            inputs_embeds=prompt_embs,
            max_length=self.max_length,
            num_beams=5,
            do_sample=self.do_sample,
            min_length=1,
            top_p=0.9,
            repetition_penalty=1.0,
            length_penalty=-1.0,
            temperature=1.0,
            stopping_criteria=self.stopping_criteria,
            num_return_sequences=1)

        # Decode each output sequence and attach it to its data sample.
        for i, data_sample in enumerate(data_samples):
            output_token = outputs[i]
            output_text = self.post_processor(output_token,
                                              self.llama_tokenizer)
            if self.is_caption_task:
                data_sample.pred_caption = output_text
            else:
                data_sample.pred_answer = output_text
            data_samples[i] = data_sample
        return data_samples

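    # Scoring mode: rather than generating text, compute the language-model
    # loss of each candidate answer in ``data_samples[0].choices`` given the
    # image and prompt, and attach the resulting scores to every data sample.
    # With the default ``n_segments=1`` all candidates are scored in one
    # pass. ``reduction='none'`` is not a standard ``transformers`` forward
    # argument; the call assumes the LLaMA implementation loaded through
    # MiniGPT-4 returns an unreduced (per-token) loss when given it.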
    def loss(self, batch):
        inputs = self.pack_inputs(batch)
        inputs = self.prompt_constructor(inputs)
        image = inputs['image']
        batch_size = image.size(0)
        prompt = inputs['prompt']
        data_samples = inputs['data_samples']
        choices = data_samples[0].choices

        with torch.no_grad():
            img_embeds, atts_img = self.encode_img(image)
            img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img,
                                                    prompt)

        self.llama_tokenizer.padding_side = 'right'

        n_cands = len(choices)
        losses = []
        for n in range(self.n_segments):
            seg_len = n_cands // self.n_segments
            if n == (self.n_segments - 1):
                # The last segment absorbs the remainder.
                seg_len = n_cands - seg_len * (self.n_segments - 1)

            # Tokenize the candidate answers (no BOS; padding is masked out
            # below).
            to_regress_tokens = self.llama_tokenizer(
                choices,
                return_tensors='pt',
                padding='longest',
                truncation=True,
                max_length=self.max_txt_len,
                add_special_tokens=False).to(image.device)

            # Padding tokens and the BOS/image positions are labelled -100 so
            # they do not contribute to the loss.
            targets = to_regress_tokens.input_ids.masked_fill(
                to_regress_tokens.input_ids ==
                self.llama_tokenizer.pad_token_id, -100)

            empty_targets = torch.ones(
                [atts_img.shape[0], atts_img.shape[1] + 1],
                dtype=torch.long).to(image.device).fill_(-100)
            empty_targets = empty_targets.repeat_interleave(seg_len, dim=0)
            targets = torch.cat([empty_targets, targets], dim=1)

            # Build the BOS + image + candidate embedding sequence, repeating
            # the visual context so each candidate row shares the same image
            # prefix.
            bos = torch.ones([batch_size, 1],
                             dtype=to_regress_tokens.input_ids.dtype,
                             device=to_regress_tokens.input_ids.device
                             ) * self.llama_tokenizer.bos_token_id
            bos_embeds = self.llama_model.model.embed_tokens(bos)
            bos_embeds = bos_embeds.repeat_interleave(seg_len, dim=0)
            img_embeds = img_embeds.repeat_interleave(seg_len, dim=0)

            atts_bos = atts_img[:, :1]
            atts_bos = atts_bos.repeat_interleave(seg_len, dim=0)
            atts_img = atts_img.repeat_interleave(seg_len, dim=0)

            to_regress_embeds = self.llama_model.model.embed_tokens(
                to_regress_tokens.input_ids)

            inputs_embeds = torch.cat(
                [bos_embeds, img_embeds, to_regress_embeds], dim=1)
            attention_mask = torch.cat(
                [atts_bos, atts_img, to_regress_tokens.attention_mask],
                dim=1)

            with self.maybe_autocast():
                outputs = self.llama_model(
                    inputs_embeds=inputs_embeds,
                    attention_mask=attention_mask,
                    return_dict=True,
                    labels=targets,
                    reduction='none',
                )
            loss = outputs.loss
            # Sum per-token losses to get one score per candidate sequence.
            loss = loss.view(targets.size(0), -1).sum(1)
            loss = loss.reshape(batch_size, seg_len)
            losses.append(loss)

        # Keep the candidate scores of the first sample in the batch.
        losses = torch.cat(losses, dim=-1)[0]

        # Attach the per-candidate losses to every data sample.
        for i, data_sample in enumerate(data_samples):
            data_sample.losses = losses
            data_samples[i] = data_sample
        return data_samples