# PromptNet/modules/att_models.py
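"""Attention-based captioning models for PromptNet.

Defines :class:`AttModel`, which wraps a GPT-style decoder and provides greedy
and beam-search sampling from projected CLIP features.
"""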
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import pdb
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence
import modules.utils as utils
from modules.caption_model import CaptionModel


class AttModel(CaptionModel):
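    """Caption model that decodes text with a GPT-style decoder from projected CLIP features.

    Only sampling is implemented here; ``self.clip_project``, ``self.decoder`` (a GPT-2-style
    language model with ``transformer.wte``) and ``self._forward`` are expected to be defined
    by subclasses.
    """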
    def __init__(self, args, tokenizer):
        super(AttModel, self).__init__()
        self.args = args
        self.tokenizer = tokenizer
        self.vocab_size = len(tokenizer.idx2token)
        self.max_seq_length = 60

    def _sample(self, clip_features, gpt_tokens, update_opts=None):
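        """Dispatch to greedy or beam-search decoding based on ``sample_method`` in the run options."""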
        opt = self.args.__dict__
        opt.update(**(update_opts or {}))
        sample_method = opt.get('sample_method', 'greedy')
        if sample_method == 'greedy':
            return self._greedy_sample(clip_features, gpt_tokens)
        elif sample_method == 'beam_search':
            return self._beam_search_sample(clip_features, gpt_tokens)
        else:
            raise ValueError("Unknown sample_method: " + sample_method)

    def _greedy_sample(self, clip_features, gpt_tokens, temperature=1.0):
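        """Greedy decoding: feed each argmax token's embedding back into the decoder and return
        a (batch, 200) id tensor padded with ``pad_token_id``."""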
        # input_ids = torch.full((clip_features.size(0), 1), self.tokenizer.bos_token_id).type_as(clip_features).long()
        # Project CLIP features into the decoder's embedding space as the initial input
        clip_features = self.clip_project(clip_features).reshape(clip_features.size(0), 1, -1)
        tokens = [None for _ in range(clip_features.size(0))]
        finished = [False for _ in range(clip_features.size(0))]
        max_length = 200
        for _ in range(max_length):
            outputs = self.decoder(inputs_embeds=clip_features)
            logits = outputs.logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
            next_tokens = torch.argmax(logits, -1).unsqueeze(1)
            next_token_embeds = self.decoder.transformer.wte(next_tokens)
            for j in range(clip_features.size(0)):
                if finished[j]:
                    continue
                if tokens[j] is None:
                    tokens[j] = next_tokens[j]
                else:
                    tokens[j] = torch.cat((tokens[j], next_tokens[j]), dim=0)
                if next_tokens[j].item() == self.tokenizer.eos_token_id:
                    finished[j] = True
            # Append the new token embeddings and stop early once every sequence has emitted EOS
            clip_features = torch.cat((clip_features, next_token_embeds), dim=1)
            if all(finished):
                break
        outputs = []
        for token in tokens:
            try:
                # reshape(-1) keeps a 1-D sequence even when only a single token was generated
                output_list = token.reshape(-1).cpu().numpy().tolist()
                # Pad or truncate output_list to max_length
                output_list = (output_list + [self.tokenizer.pad_token_id] * max_length)[:max_length]
            except Exception as e:
                print(f"Error during decoding: {type(e).__name__}: {e}")
                output_list = [self.tokenizer.pad_token_id] * max_length
            outputs.append(output_list)
        # Convert the list of lists to a single (batch, max_length) tensor
        outputs = torch.tensor(outputs, device=clip_features.device)
        return outputs

    def _beam_search_sample(self, clip_features, gpt_tokens, beam_size=5):
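        """Log-probability beam search; returns the token ids of all beams per batch element."""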
        batch_size = clip_features.size(0)
        # Start every beam from the BOS token
        input_ids = torch.full((batch_size * beam_size, 1), self.tokenizer.bos_token_id,
                               dtype=torch.long, device=clip_features.device)
        # Running log-probability of each beam; all but the first beam start at -inf so the
        # initially identical beams do not produce duplicate candidates
        beam_scores = torch.zeros((batch_size, beam_size), device=clip_features.device)
        beam_scores[:, 1:] = -float('inf')
        expanded_features = clip_features.repeat_interleave(beam_size, 0)
        for _ in range(self.max_seq_length):
            outputs = self._forward(expanded_features, input_ids)
            next_token_logits = outputs.logits[:, -1, :]
            # Work in log space so beam scores accumulate by addition
            next_token_logprobs = F.log_softmax(next_token_logits, dim=-1)
            next_token_logprobs = next_token_logprobs.view(batch_size, beam_size, -1)
            # Add each beam's running score to the log-probabilities of its continuations
            scores = beam_scores.unsqueeze(2) + next_token_logprobs
            scores = scores.view(batch_size, -1)
            # Keep the beam_size best (beam, token) pairs per batch element
            beam_scores, top_indices = scores.topk(beam_size, dim=1)
            next_token_ids = top_indices % self.vocab_size
            beam_indices = top_indices // self.vocab_size
            # Reorder the surviving beams and append the newly chosen tokens
            input_ids = input_ids.view(batch_size, beam_size, -1)
            input_ids = torch.cat([
                input_ids.gather(1, beam_indices.unsqueeze(2).expand(-1, -1, input_ids.size(2))),
                next_token_ids.unsqueeze(2),
            ], dim=2)
            input_ids = input_ids.view(batch_size * beam_size, -1)
            # Stop once every beam of every batch element has emitted EOS on the same step
            done = (next_token_ids == self.tokenizer.eos_token_id).all(dim=1).tolist()
            if all(done):
                break
        return input_ids.view(batch_size, beam_size, -1)