import torch
import math
from typing import Dict, List, Optional
from fairseq.sequence_generator import SequenceGenerator
from torch import Tensor


class TextRecognitionGenerator(SequenceGenerator):
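    """Beam-search generator for image-based text recognition.

    A fairseq SequenceGenerator whose source input is an image batch
    (sample["net_input"]["imgs"]) rather than token ids; src_tokens and
    src_lengths are reconstructed from the encoder padding mask.
    """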
    def _generate(
        self,
        sample: Dict[str, Dict[str, Tensor]],
        prefix_tokens: Optional[Tensor] = None,
        constraints: Optional[Tensor] = None,
        bos_token: Optional[int] = None,
    ):
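        # One decoder cache per model in the ensemble, reused across timesteps
        # for incremental decoding.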
        incremental_states = torch.jit.annotate(
            List[Dict[str, Dict[str, Optional[Tensor]]]],
            [
                torch.jit.annotate(Dict[str, Dict[str, Optional[Tensor]]], {})
                for i in range(self.model.models_size)
            ],
        )
        net_input = sample["net_input"]
        device = sample["net_input"]["imgs"].device

        # compute the encoder output for each beam
        # "encoder_out": [x],  # T x B x C
        # "encoder_padding_mask": [encoder_padding_mask],  # B x T
        # "encoder_embedding": [encoder_embedding],  # B x T x C
        # "encoder_states": [],  # List[T x B x C]
        # "src_tokens": [],
        # "src_lengths": [],
        encoder_outs = self.model.forward_encoder(net_input)  # T x B x C
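        # There are no source tokens for an image input, so the encoder padding
        # mask (0 = valid position) stands in for src_tokens, and its per-row
        # count of valid positions gives src_lengths.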
        src_lengths = (
            encoder_outs[0]["encoder_padding_mask"][0].eq(0).long().sum(dim=1)
        )  # B
        src_tokens = encoder_outs[0]["encoder_padding_mask"][0]  # B x T

        # bsz: total number of sentences in beam
        # Note that src_tokens may have more than 2 dimensions (e.g. image features)
        bsz, src_len = src_tokens.size()[:2]
        beam_size = self.beam_size

        if constraints is not None and not self.search.supports_constraints:
            raise NotImplementedError(
                "Target-side constraints were provided, but search method doesn't support them"
            )

        # Initialize constraints, when active
        self.search.init_constraints(constraints, beam_size)

        max_len: int = -1
        if self.match_source_len:
            max_len = src_lengths.max().item()
        else:
            max_len = min(
                int(self.max_len_a * src_len + self.max_len_b),
                # exclude the EOS marker
                self.model.max_decoder_positions() - 1,
            )
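        # Illustrative arithmetic: with e.g. max_len_a=0 and max_len_b=200
        # (common fairseq defaults), max_len = min(200, max_decoder_positions() - 1),
        # i.e. a fixed cap unless the decoder's positional limit is smaller.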
        assert (
            self.min_len <= max_len
        ), "min_len cannot be larger than max_len, please adjust these!"

        # placeholder of indices for bsz * beam_size to hold tokens and accumulative scores
        new_order = torch.arange(bsz).view(-1, 1).repeat(1, beam_size).view(-1)
        new_order = new_order.to(src_tokens.device).long()
        encoder_outs = self.model.reorder_encoder_out(encoder_outs, new_order)
        # ensure encoder_outs is a List.
        assert encoder_outs is not None

        # initialize buffers
        scores = (
            torch.zeros(bsz * beam_size, max_len + 1).to(src_tokens).float()
        )  # +1 for eos; pad is never chosen for scoring
        tokens = (
            torch.zeros(bsz * beam_size, max_len + 2)
            .to(src_tokens)
            .long()
            .fill_(self.pad)
        )  # +2 for eos and pad
        tokens[:, 0] = self.eos if bos_token is None else bos_token
        attn: Optional[Tensor] = None

        # A list that indicates candidates that should be ignored.
        # For example, suppose we're sampling and have already finalized 2/5
        # samples. Then cands_to_ignore would mark 2 positions as being ignored,
        # so that we only finalize the remaining 3 samples.
        cands_to_ignore = (
            torch.zeros(bsz, beam_size).to(src_tokens).eq(-1)
        )  # forward and backward-compatible False mask

        # list of completed sentences
        finalized = torch.jit.annotate(
            List[List[Dict[str, Tensor]]],
            [torch.jit.annotate(List[Dict[str, Tensor]], []) for i in range(bsz)],
        )  # contains lists of dictionaries of information about the hypothesis being finalized at each step
        finished = [
            False for i in range(bsz)
        ]  # a boolean array indicating if the sentence at the index is finished or not
        num_remaining_sent = bsz  # number of sentences remaining

        # number of candidate hypos per step
        cand_size = 2 * beam_size  # 2 x beam size in case half are EOS

        # offset arrays for converting between different indexing schemes
        bbsz_offsets = (
            (torch.arange(0, bsz) * beam_size)
            .unsqueeze(1)
            .type_as(tokens)
            .to(src_tokens.device)
        )
        cand_offsets = torch.arange(0, cand_size).type_as(tokens).to(src_tokens.device)
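
        # Illustrative example of the two indexing schemes (assuming bsz=2 and
        # beam_size=3, so cand_size=6):
        #   bbsz_offsets = [[0], [3]]     -> flat "bbsz" rows 0..2 hold sentence 0's
        #                                    beams and rows 3..5 sentence 1's
        #   cand_offsets = [0, 1, ..., 5] -> per-candidate offsets used to build
        #                                    active_mask inside the search loop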
        reorder_state: Optional[Tensor] = None
        batch_idxs: Optional[Tensor] = None
        original_batch_idxs: Optional[Tensor] = None
        if "id" in sample and isinstance(sample["id"], Tensor):
            original_batch_idxs = sample["id"]
        else:
            original_batch_idxs = torch.arange(0, bsz).type_as(tokens)

        for step in range(max_len + 1):  # one extra step for EOS marker
            # reorder decoder internal states based on the prev choice of beams
            if reorder_state is not None:
                if batch_idxs is not None:
                    # update beam indices to take into account removed sentences
                    corr = batch_idxs - torch.arange(batch_idxs.numel()).type_as(
                        batch_idxs
                    )
                    reorder_state.view(-1, beam_size).add_(
                        corr.unsqueeze(-1) * beam_size
                    )
                    original_batch_idxs = original_batch_idxs[batch_idxs]
                self.model.reorder_incremental_state(incremental_states, reorder_state)
                encoder_outs = self.model.reorder_encoder_out(
                    encoder_outs, reorder_state
                )

            lprobs, avg_attn_scores = self.model.forward_decoder(
                tokens[:, : step + 1],
                encoder_outs,
                incremental_states,
                self.temperature,
            )
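            # lprobs: (bsz * beam_size, vocab_size) next-token log-probabilities;
            # avg_attn_scores, when returned: (bsz * beam_size, src_len) attention
            # over encoder positions, averaged across the ensemble.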
            if self.lm_model is not None:
                lm_out = self.lm_model(tokens[:, : step + 1])
                probs = self.lm_model.get_normalized_probs(
                    lm_out, log_probs=True, sample=None
                )
                probs = probs[:, -1, :] * self.lm_weight
                lprobs += probs

            # replace NaN scores (NaN != NaN) with -inf
            lprobs[lprobs != lprobs] = torch.tensor(-math.inf).to(lprobs)
            lprobs[:, self.pad] = -math.inf  # never select pad
            lprobs[:, self.unk] -= self.unk_penalty  # apply unk penalty

            # handle max length constraint
            if step >= max_len:
                lprobs[:, : self.eos] = -math.inf
                lprobs[:, self.eos + 1 :] = -math.inf

            # handle prefix tokens (possibly with different lengths)
            if (
                prefix_tokens is not None
                and step < prefix_tokens.size(1)
                and step < max_len
            ):
                lprobs, tokens, scores = self._prefix_tokens(
                    step, lprobs, scores, tokens, prefix_tokens, beam_size
                )
            elif step < self.min_len:
                # minimum length constraint (does not apply if using prefix_tokens)
                lprobs[:, self.eos] = -math.inf
            # Record attention scores; only supported when avg_attn_scores is a Tensor
            if avg_attn_scores is not None:
                if attn is None:
                    attn = torch.empty(
                        bsz * beam_size, avg_attn_scores.size(1), max_len + 2
                    ).to(scores)
                attn[:, :, step + 1].copy_(avg_attn_scores)
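                # attn accumulates one column of attention per generated token:
                # (bsz * beam_size, src_len, max_len + 2)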
            scores = scores.type_as(lprobs)
            eos_bbsz_idx = torch.empty(0).to(
                tokens
            )  # indices of hypothesis ending with eos (finished sentences)
            eos_scores = torch.empty(0).to(
                scores
            )  # scores of hypothesis ending with eos (finished sentences)

            if self.should_set_src_lengths:
                self.search.set_src_lengths(src_lengths)

            if self.repeat_ngram_blocker is not None:
                lprobs = self.repeat_ngram_blocker(tokens, lprobs, bsz, beam_size, step)

            # Shape: (batch, cand_size)
            cand_scores, cand_indices, cand_beams = self.search.step(
                step,
                lprobs.view(bsz, -1, self.vocab_size),
                scores.view(bsz, beam_size, -1)[:, :, :step],
                tokens[:, : step + 1],
                original_batch_idxs,
            )

            # cand_bbsz_idx contains beam indices for the top candidate
            # hypotheses, with a range of values: [0, bsz*beam_size),
            # and dimensions: [bsz, cand_size]
            cand_bbsz_idx = cand_beams.add(bbsz_offsets)

            # finalize hypotheses that end in eos
            # Shape of eos_mask: (batch size, cand_size)
            eos_mask = cand_indices.eq(self.eos) & cand_scores.ne(-math.inf)
            eos_mask[:, :beam_size][cands_to_ignore] = torch.tensor(0).to(eos_mask)

            # only consider eos when it's among the top beam_size indices
            # Now we know what beam item(s) to finish
            # Shape: 1d list of absolute-numbered (bbsz) indices
            eos_bbsz_idx = torch.masked_select(
                cand_bbsz_idx[:, :beam_size], mask=eos_mask[:, :beam_size]
            )
            finalized_sents: List[int] = []
            if eos_bbsz_idx.numel() > 0:
                eos_scores = torch.masked_select(
                    cand_scores[:, :beam_size], mask=eos_mask[:, :beam_size]
                )
                finalized_sents = self.finalize_hypos(
                    step,
                    eos_bbsz_idx,
                    eos_scores,
                    tokens,
                    scores,
                    finalized,
                    finished,
                    beam_size,
                    attn,
                    src_lengths,
                    max_len,
                )
                num_remaining_sent -= len(finalized_sents)

            assert num_remaining_sent >= 0
            if num_remaining_sent == 0:
                break
            if self.search.stop_on_max_len and step >= max_len:
                break
            assert step < max_len, f"{step} < {max_len}"
            # Remove finalized sentences (ones for which {beam_size}
            # finished hypotheses have been generated) from the batch.
            if len(finalized_sents) > 0:
                new_bsz = bsz - len(finalized_sents)

                # construct batch_idxs which holds indices of batches to keep for the next pass
                batch_mask = torch.ones(
                    bsz, dtype=torch.bool, device=cand_indices.device
                )
                batch_mask[finalized_sents] = False
                # TODO replace `nonzero(as_tuple=False)` after TorchScript supports it
                batch_idxs = torch.arange(
                    bsz, device=cand_indices.device
                ).masked_select(batch_mask)

                # Choose the subset of the hypothesized constraints that will continue
                self.search.prune_sentences(batch_idxs)

                eos_mask = eos_mask[batch_idxs]
                cand_beams = cand_beams[batch_idxs]
                bbsz_offsets.resize_(new_bsz, 1)
                cand_bbsz_idx = cand_beams.add(bbsz_offsets)
                cand_scores = cand_scores[batch_idxs]
                cand_indices = cand_indices[batch_idxs]

                if prefix_tokens is not None:
                    prefix_tokens = prefix_tokens[batch_idxs]
                src_lengths = src_lengths[batch_idxs]
                cands_to_ignore = cands_to_ignore[batch_idxs]

                scores = scores.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
                tokens = tokens.view(bsz, -1)[batch_idxs].view(new_bsz * beam_size, -1)
                if attn is not None:
                    attn = attn.view(bsz, -1)[batch_idxs].view(
                        new_bsz * beam_size, attn.size(1), -1
                    )
                bsz = new_bsz
            else:
                batch_idxs = None

            # Set active_mask so that values > cand_size indicate eos hypos
            # and values < cand_size indicate candidate active hypos.
            # After, the min values per row are the top candidate active hypos.
            # Rewrite the operator since the element-wise or is not supported in torchscript.
            eos_mask[:, :beam_size] = ~((~cands_to_ignore) & (~eos_mask[:, :beam_size]))
            active_mask = torch.add(
                eos_mask.type_as(cand_offsets) * cand_size,
                cand_offsets[: eos_mask.size(1)],
            )
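            # Illustrative example (bsz=1, beam_size=2, cand_size=4):
            #   eos_mask    = [[True, False, False, True]]
            #   active_mask = [[1*4 + 0, 0 + 1, 0 + 2, 1*4 + 3]] = [[4, 1, 2, 7]]
            # so the topk(..., largest=False) below selects the two non-eos
            # candidates (columns 1 and 2) as the surviving beams.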
            # get the top beam_size active hypotheses, which are just
            # the hypos with the smallest values in active_mask.
            # {active_hypos} indicates which {beam_size} hypotheses
            # from the list of {2 * beam_size} candidates were
            # selected. Shapes: (batch size, beam size)
            new_cands_to_ignore, active_hypos = torch.topk(
                active_mask, k=beam_size, dim=1, largest=False
            )

            # update cands_to_ignore to ignore any finalized hypos.
            cands_to_ignore = new_cands_to_ignore.ge(cand_size)[:, :beam_size]
            # Make sure there is at least one active item for each sentence in the batch.
            assert (~cands_to_ignore).any(dim=1).all()

            # {active_bbsz_idx} denotes which beam number is continued for each new hypothesis (a beam
            # can be selected more than once).
            active_bbsz_idx = torch.gather(cand_bbsz_idx, dim=1, index=active_hypos)
            active_scores = torch.gather(cand_scores, dim=1, index=active_hypos)

            active_bbsz_idx = active_bbsz_idx.view(-1)
            active_scores = active_scores.view(-1)

            # copy tokens and scores for active hypotheses
            # Set the tokens for each beam (can select the same row more than once)
            tokens[:, : step + 1] = torch.index_select(
                tokens[:, : step + 1], dim=0, index=active_bbsz_idx
            )
            # Select the next token for each of them
            tokens.view(bsz, beam_size, -1)[:, :, step + 1] = torch.gather(
                cand_indices, dim=1, index=active_hypos
            )
            if step > 0:
                scores[:, :step] = torch.index_select(
                    scores[:, :step], dim=0, index=active_bbsz_idx
                )
            scores.view(bsz, beam_size, -1)[:, :, step] = torch.gather(
                cand_scores, dim=1, index=active_hypos
            )
            # Update constraints based on which candidates were selected for the next beam
            self.search.update_constraints(active_hypos)

            # copy attention for active hypotheses
            if attn is not None:
                attn[:, :, : step + 2] = torch.index_select(
                    attn[:, :, : step + 2], dim=0, index=active_bbsz_idx
                )

            # reorder incremental state in decoder
            reorder_state = active_bbsz_idx
        # sort by score descending
        for sent in range(len(finalized)):
            scores = torch.tensor(
                [float(elem["score"].item()) for elem in finalized[sent]]
            )
            _, sorted_scores_indices = torch.sort(scores, descending=True)
            finalized[sent] = [finalized[sent][ssi] for ssi in sorted_scores_indices]
            finalized[sent] = torch.jit.annotate(
                List[Dict[str, Tensor]], finalized[sent]
            )
        return finalized
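

# Minimal usage sketch (illustrative only: `task`, `cfg`, and `image_batch`
# below are assumed to come from the surrounding fairseq setup and are not
# defined in this module):
#
#   models = [task.build_model(cfg)]
#   generator = TextRecognitionGenerator(models, task.target_dictionary, beam_size=5)
#   sample = {"net_input": {"imgs": image_batch}}  # B x C x H x W image tensor
#   hypos = generator._generate(sample)
#   best_tokens = hypos[0][0]["tokens"]  # highest-scoring hypothesis, first image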