from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence

import modules.utils as utils
from modules.caption_model import CaptionModel
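

# Helpers for running a module over variable-length attention features: the
# features are sorted by length, packed into a PackedSequence, transformed, then
# padded back and restored to the original batch order.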
def sort_pack_padded_sequence(input, lengths):
    sorted_lengths, indices = torch.sort(lengths, descending=True)
    tmp = pack_padded_sequence(input[indices], sorted_lengths, batch_first=True)
    inv_ix = indices.clone()
    inv_ix[indices] = torch.arange(0, len(indices)).type_as(inv_ix)
    return tmp, inv_ix


def pad_unsort_packed_sequence(input, inv_ix):
    tmp, _ = pad_packed_sequence(input, batch_first=True)
    tmp = tmp[inv_ix]
    return tmp
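

# pack_wrapper applies `module` to att_feats while respecting att_masks: when a
# mask is given, the per-sample lengths are taken from the mask sums and only
# the valid timesteps (as a 2-D packed tensor) are pushed through the module;
# otherwise the features are passed through directly. The mask is assumed to
# hold ones at valid positions (an inference from the length computation, not
# something stated in this file).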
def pack_wrapper(module, att_feats, att_masks):
    if att_masks is not None:
        packed, inv_ix = sort_pack_padded_sequence(att_feats, att_masks.data.long().sum(1))
        return pad_unsort_packed_sequence(PackedSequence(module(packed[0]), packed[1]), inv_ix)
    else:
        return module(att_feats)
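

# AttModel is a base class: several members used below (ctx2att, logit, core,
# init_hidden, sample_next_word and beam_search) are not defined in this file
# and are presumably provided by subclasses or by the CaptionModel parent.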
class AttModel(CaptionModel):
    def __init__(self, args, tokenizer):
        super(AttModel, self).__init__()
        self.args = args
        self.tokenizer = tokenizer
        self.vocab_size = len(tokenizer.idx2token)
        self.input_encoding_size = args.d_model
        self.rnn_size = args.d_ff
        self.num_layers = args.num_layers
        self.drop_prob_lm = args.drop_prob_lm
        self.max_seq_length = args.max_seq_length
        self.att_feat_size = args.d_vf
        self.att_hid_size = args.d_model

        self.bos_idx = args.bos_idx
        self.eos_idx = args.eos_idx
        self.pad_idx = args.pad_idx

        self.use_bn = args.use_bn

        self.embed = lambda x: x
        self.fc_embed = lambda x: x
        self.att_embed = nn.Sequential(*(
                ((nn.BatchNorm1d(self.att_feat_size),) if self.use_bn else ()) +
                (nn.Linear(self.att_feat_size, self.input_encoding_size),
                 nn.ReLU(),
                 nn.Dropout(self.drop_prob_lm)) +
                ((nn.BatchNorm1d(self.input_encoding_size),) if self.use_bn == 2 else ())))

    def clip_att(self, att_feats, att_masks):
        # Clip the length of att_masks and att_feats to the maximum length
        if att_masks is not None:
            max_len = att_masks.data.long().sum(1).max()
            att_feats = att_feats[:, :max_len].contiguous()
            att_masks = att_masks[:, :max_len].contiguous()
        return att_feats, att_masks

    def _prepare_feature(self, fc_feats, att_feats, att_masks):
        att_feats, att_masks = self.clip_att(att_feats, att_masks)
        # embed fc and att feats
        fc_feats = self.fc_embed(fc_feats)
        att_feats = pack_wrapper(self.att_embed, att_feats, att_masks)
        # Project the attention feats first to reduce memory and computation consumption.
        p_att_feats = self.ctx2att(att_feats)
        return fc_feats, att_feats, p_att_feats, att_masks

    def get_logprobs_state(self, it, fc_feats, att_feats, p_att_feats, att_masks, state, output_logsoftmax=1):
        # 'it' contains a word index
        xt = self.embed(it)
        output, state = self.core(xt, fc_feats, att_feats, p_att_feats, state, att_masks)
        if output_logsoftmax:
            logprobs = F.log_softmax(self.logit(output), dim=1)
        else:
            logprobs = self.logit(output)
        return logprobs, state
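
    # Beam-search decoding: encode the visual features once, take a single <bos>
    # step to get the initial log-probabilities, replicate the features beam_size
    # times, and let beam_search (assumed to live in the CaptionModel base) fill
    # self.done_beams, from which the top sequences are copied out.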
    def _sample_beam(self, fc_feats, att_feats, att_masks=None, opt={}):
        beam_size = opt.get('beam_size', 10)
        group_size = opt.get('group_size', 1)
        sample_n = opt.get('sample_n', 10)
        # when sample_n == beam_size then each beam is a sample.
        assert sample_n == 1 or sample_n == beam_size // group_size, \
            'when doing beam search, sample_n must be 1 or beam_size // group_size'
        batch_size = fc_feats.size(0)

        p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = self._prepare_feature(fc_feats, att_feats, att_masks)

        assert beam_size <= self.vocab_size + 1, 'lets assume this for now, otherwise this corner case causes a few headaches down the road. can be dealt with in future if needed'
        seq = fc_feats.new_full((batch_size * sample_n, self.max_seq_length), self.pad_idx, dtype=torch.long)
        seqLogprobs = fc_feats.new_zeros(batch_size * sample_n, self.max_seq_length, self.vocab_size + 1)
        # lets process every image independently for now, for simplicity
        self.done_beams = [[] for _ in range(batch_size)]

        state = self.init_hidden(batch_size)

        # first step, feed bos
        it = fc_feats.new_full([batch_size], self.bos_idx, dtype=torch.long)
        logprobs, state = self.get_logprobs_state(it, p_fc_feats, p_att_feats, pp_att_feats, p_att_masks, state)

        p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = utils.repeat_tensors(
            beam_size, [p_fc_feats, p_att_feats, pp_att_feats, p_att_masks])
        self.done_beams = self.beam_search(state, logprobs, p_fc_feats, p_att_feats, pp_att_feats, p_att_masks, opt=opt)
        for k in range(batch_size):
            if sample_n == beam_size:
                for _n in range(sample_n):
                    seq_len = self.done_beams[k][_n]['seq'].shape[0]
                    seq[k * sample_n + _n, :seq_len] = self.done_beams[k][_n]['seq']
                    seqLogprobs[k * sample_n + _n, :seq_len] = self.done_beams[k][_n]['logps']
            else:
                seq_len = self.done_beams[k][0]['seq'].shape[0]
                seq[k, :seq_len] = self.done_beams[k][0]['seq']  # the first beam has highest cumulative score
                seqLogprobs[k, :seq_len] = self.done_beams[k][0]['logps']
        # return the samples and their log likelihoods
        return seq, seqLogprobs
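
    # Greedy / temperature sampling: decoding options are read from self.args.
    # Depending on those options this dispatches to _sample_beam or
    # _diverse_sample; otherwise it decodes one token at a time, optionally
    # forbidding immediate repetition (decoding_constraint) and previously
    # generated trigrams (block_trigrams).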
    def _sample(self, fc_feats, att_feats, att_masks=None):
        opt = self.args.__dict__
        sample_method = opt.get('sample_method', 'greedy')
        beam_size = opt.get('beam_size', 1)
        temperature = opt.get('temperature', 1.0)
        sample_n = int(opt.get('sample_n', 1))
        group_size = opt.get('group_size', 1)
        output_logsoftmax = opt.get('output_logsoftmax', 1)
        decoding_constraint = opt.get('decoding_constraint', 0)
        block_trigrams = opt.get('block_trigrams', 0)
        if beam_size > 1 and sample_method in ['greedy', 'beam_search']:
            return self._sample_beam(fc_feats, att_feats, att_masks, opt)
        if group_size > 1:
            return self._diverse_sample(fc_feats, att_feats, att_masks, opt)

        batch_size = fc_feats.size(0)
        state = self.init_hidden(batch_size * sample_n)

        p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = self._prepare_feature(fc_feats, att_feats, att_masks)

        if sample_n > 1:
            p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = utils.repeat_tensors(
                sample_n, [p_fc_feats, p_att_feats, pp_att_feats, p_att_masks])

        trigrams = []  # will be a list of batch_size dictionaries

        seq = fc_feats.new_full((batch_size * sample_n, self.max_seq_length), self.pad_idx, dtype=torch.long)
        seqLogprobs = fc_feats.new_zeros(batch_size * sample_n, self.max_seq_length, self.vocab_size + 1)
        for t in range(self.max_seq_length + 1):
            if t == 0:  # input <bos>
                it = fc_feats.new_full([batch_size * sample_n], self.bos_idx, dtype=torch.long)

            logprobs, state = self.get_logprobs_state(it, p_fc_feats, p_att_feats, pp_att_feats, p_att_masks, state,
                                                      output_logsoftmax=output_logsoftmax)

            if decoding_constraint and t > 0:
                tmp = logprobs.new_zeros(logprobs.size())
                tmp.scatter_(1, seq[:, t - 1].data.unsqueeze(1), float('-inf'))
                logprobs = logprobs + tmp

            # Mess with trigrams
            # Copy from https://github.com/lukemelas/image-paragraph-captioning
            if block_trigrams and t >= 3:
                # Store trigram generated at last step
                prev_two_batch = seq[:, t - 3:t - 1]
                for i in range(batch_size):  # = seq.size(0)
                    prev_two = (prev_two_batch[i][0].item(), prev_two_batch[i][1].item())
                    current = seq[i][t - 1]
                    if t == 3:  # initialize
                        trigrams.append({prev_two: [current]})  # {LongTensor: list containing 1 int}
                    elif t > 3:
                        if prev_two in trigrams[i]:  # add to list
                            trigrams[i][prev_two].append(current)
                        else:  # create list
                            trigrams[i][prev_two] = [current]
                # Block used trigrams at next step
                prev_two_batch = seq[:, t - 2:t]
                mask = torch.zeros(logprobs.size(), requires_grad=False).cuda()  # batch_size x vocab_size
                for i in range(batch_size):
                    prev_two = (prev_two_batch[i][0].item(), prev_two_batch[i][1].item())
                    if prev_two in trigrams[i]:
                        for j in trigrams[i][prev_two]:
                            mask[i, j] += 1
                # Apply mask to log probs
                # logprobs = logprobs - (mask * 1e9)
                alpha = 2.0  # = 4
                logprobs = logprobs + (mask * -0.693 * alpha)  # ln(1/2) * alpha (alpha -> infty works best)

            # sample the next word
            if t == self.max_seq_length:  # skip if we achieve maximum length
                break
            it, sampleLogprobs = self.sample_next_word(logprobs, sample_method, temperature)

            # stop when all finished
            if t == 0:
                unfinished = it != self.eos_idx
            else:
                it[~unfinished] = self.pad_idx  # keep padding after a sequence finishes so its eos token is not overwritten
                logprobs = logprobs * unfinished.unsqueeze(1).float()
                unfinished = unfinished * (it != self.eos_idx)
            seq[:, t] = it
            seqLogprobs[:, t] = logprobs
            # quit loop if all sequences have finished
            if unfinished.sum() == 0:
                break

        return seq, seqLogprobs
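
    # Diverse decoding across group_size groups: group divm starts decoding divm
    # steps after group divm - 1, and at each step the tokens already chosen by
    # earlier groups at the same position are penalised by diversity_lambda
    # before sampling (a diverse-beam-search style penalty).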
    def _diverse_sample(self, fc_feats, att_feats, att_masks=None, opt={}):
        sample_method = opt.get('sample_method', 'greedy')
        beam_size = opt.get('beam_size', 1)
        temperature = opt.get('temperature', 1.0)
        group_size = opt.get('group_size', 1)
        diversity_lambda = opt.get('diversity_lambda', 0.5)
        decoding_constraint = opt.get('decoding_constraint', 0)
        block_trigrams = opt.get('block_trigrams', 0)

        batch_size = fc_feats.size(0)
        state = self.init_hidden(batch_size)

        p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = self._prepare_feature(fc_feats, att_feats, att_masks)

        trigrams_table = [[] for _ in range(group_size)]  # will be a list of batch_size dictionaries

        seq_table = [fc_feats.new_full((batch_size, self.max_seq_length), self.pad_idx, dtype=torch.long)
                     for _ in range(group_size)]
        seqLogprobs_table = [fc_feats.new_zeros(batch_size, self.max_seq_length) for _ in range(group_size)]
        state_table = [self.init_hidden(batch_size) for _ in range(group_size)]

        for tt in range(self.max_seq_length + group_size):
            for divm in range(group_size):
                t = tt - divm
                seq = seq_table[divm]
                seqLogprobs = seqLogprobs_table[divm]
                trigrams = trigrams_table[divm]
                if t >= 0 and t <= self.max_seq_length - 1:
                    if t == 0:  # input <bos>
                        it = fc_feats.new_full([batch_size], self.bos_idx, dtype=torch.long)
                    else:
                        it = seq[:, t - 1]  # changed

                    logprobs, state_table[divm] = self.get_logprobs_state(it, p_fc_feats, p_att_feats, pp_att_feats,
                                                                          p_att_masks, state_table[divm])  # changed
                    logprobs = F.log_softmax(logprobs / temperature, dim=-1)

                    # Add diversity
                    if divm > 0:
                        unaug_logprobs = logprobs.clone()
                        for prev_choice in range(divm):
                            prev_decisions = seq_table[prev_choice][:, t]
                            logprobs[:, prev_decisions] = logprobs[:, prev_decisions] - diversity_lambda

                    if decoding_constraint and t > 0:
                        tmp = logprobs.new_zeros(logprobs.size())
                        tmp.scatter_(1, seq[:, t - 1].data.unsqueeze(1), float('-inf'))
                        logprobs = logprobs + tmp

                    # Mess with trigrams
                    if block_trigrams and t >= 3:
                        # Store trigram generated at last step
                        prev_two_batch = seq[:, t - 3:t - 1]
                        for i in range(batch_size):  # = seq.size(0)
                            prev_two = (prev_two_batch[i][0].item(), prev_two_batch[i][1].item())
                            current = seq[i][t - 1]
                            if t == 3:  # initialize
                                trigrams.append({prev_two: [current]})  # {LongTensor: list containing 1 int}
                            elif t > 3:
                                if prev_two in trigrams[i]:  # add to list
                                    trigrams[i][prev_two].append(current)
                                else:  # create list
                                    trigrams[i][prev_two] = [current]
                        # Block used trigrams at next step
                        prev_two_batch = seq[:, t - 2:t]
                        mask = torch.zeros(logprobs.size(), requires_grad=False).cuda()  # batch_size x vocab_size
                        for i in range(batch_size):
                            prev_two = (prev_two_batch[i][0].item(), prev_two_batch[i][1].item())
                            if prev_two in trigrams[i]:
                                for j in trigrams[i][prev_two]:
                                    mask[i, j] += 1
                        # Apply mask to log probs
                        # logprobs = logprobs - (mask * 1e9)
                        alpha = 2.0  # = 4
                        logprobs = logprobs + (mask * -0.693 * alpha)  # ln(1/2) * alpha (alpha -> infty works best)

                    it, sampleLogprobs = self.sample_next_word(logprobs, sample_method, 1)

                    # stop when all finished
                    if t == 0:
                        unfinished = it != self.eos_idx
                    else:
                        unfinished = (seq[:, t - 1] != self.pad_idx) & (seq[:, t - 1] != self.eos_idx)
                        it[~unfinished] = self.pad_idx
                        unfinished = unfinished & (it != self.eos_idx)  # changed
                    seq[:, t] = it
                    seqLogprobs[:, t] = sampleLogprobs.view(-1)

        return (torch.stack(seq_table, 1).reshape(batch_size * group_size, -1),
                torch.stack(seqLogprobs_table, 1).reshape(batch_size * group_size, -1))
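

# A minimal usage sketch (an assumption, not part of this module): a concrete
# subclass that defines ctx2att, logit, core, init_hidden and sample_next_word
# could be driven roughly as follows, for B images with N regions of d_vf features:
#
#     model = MyAttModel(args, tokenizer)        # MyAttModel is hypothetical
#     fc_feats = torch.randn(B, args.d_vf)       # global feature per image (assumed shape)
#     att_feats = torch.randn(B, N, args.d_vf)   # region features
#     att_masks = torch.ones(B, N)               # ones at valid region positions
#     seq, seqLogprobs = model._sample(fc_feats, att_feats, att_masks)
#     # seq: (B * sample_n, max_seq_length) token ids, padded with pad_idx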