from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence

import modules.utils as utils
from modules.caption_model import CaptionModel


class AttModel(CaptionModel):
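    """Captioning model that decodes text from projected CLIP features.

    This class only implements the sampling logic. It assumes a concrete
    subclass provides ``self.clip_project`` (maps CLIP features into the
    decoder embedding space), ``self.decoder`` (a GPT-style language model
    exposing ``transformer.wte`` embeddings) and ``self._forward``.
    """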
    def __init__(self, args, tokenizer):
        super(AttModel, self).__init__()
        self.args = args
        self.tokenizer = tokenizer
        self.vocab_size = len(tokenizer.idx2token)
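        # Step cap used by beam-search decoding (greedy decoding below
        # uses its own local max_length)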
        self.max_seq_length = 60

    def _sample(self, clip_features, gpt_tokens, update_opts=None):
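        """Decode token sequences from CLIP features.

        Dispatches to greedy or beam-search decoding according to
        ``sample_method`` (default ``'greedy'``); per-call overrides can be
        passed via ``update_opts``.
        """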

        # Copy so that per-call overrides do not mutate self.args
        opt = dict(self.args.__dict__)
        opt.update(update_opts or {})
        sample_method = opt.get('sample_method', 'greedy')

        if sample_method == 'greedy':
            return self._greedy_sample(clip_features, gpt_tokens)
        elif sample_method == 'beam_search':
            return self._beam_search_sample(clip_features, gpt_tokens)
        else:
            raise ValueError("Unknown sample_method: " + sample_method)

    def _greedy_sample(self, clip_features, gpt_tokens, temperature=1.0):
        #input_ids = torch.full((clip_features.size(0), 1), self.tokenizer.bos_token_id).type_as(clip_features).long()
        clip_features = self.clip_project(clip_features).reshape(clip_features.size(0), 1, -1)
        tokens = [None for _ in range(clip_features.size(0))]
        finished = [False for _ in range(clip_features.size(0))]
        max_length = 200
        for _ in range(max_length):
            outputs = self.decoder(inputs_embeds=clip_features)
            logits = outputs.logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
            next_tokens = torch.argmax(logits, -1).unsqueeze(1)
            next_token_embeds = self.decoder.transformer.wte(next_tokens)
            for j in range(clip_features.size(0)):
                if finished[j]:
                    continue
                if tokens[j] is None:
                    tokens[j] = next_tokens[j]
                else:
                    tokens[j] = torch.cat((tokens[j], next_tokens[j]), dim=0)
                if next_tokens[j].item() == self.tokenizer.eos_token_id:
                    finished[j] = True
            # Stop early once every sequence in the batch has emitted EOS
            if all(finished):
                break
            clip_features = torch.cat((clip_features, next_token_embeds), dim=1)
        outputs = []
        for token in tokens:
            try:
                # reshape(-1) keeps a 1-D list even for single-token outputs
                output_list = token.reshape(-1).cpu().numpy().tolist()
                # Pad or truncate output_list to max_length
                output_list = (output_list + [self.tokenizer.pad_token_id] * max_length)[:max_length]
            except Exception as e:
                print(f"Error during decoding: {type(e).__name__}: {e}")
                output_list = [self.tokenizer.pad_token_id] * max_length
            outputs.append(output_list)

        # Convert list of lists to tensor
        outputs = torch.tensor(outputs, device=clip_features.device)
        return outputs


    def _beam_search_sample(self, clip_features, gpt_tokens, beam_size=5):
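        """Simple beam search over the decoder's output distribution.

        Keeps ``beam_size`` hypotheses per batch element, scored by summed
        token log-probabilities (no length normalization). A batch element is
        marked finished once all of its beams emit EOS at the same step, or
        after ``max_seq_length`` steps. Returns a (batch, beam, seq_len)
        tensor of token ids.
        """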
        batch_size = clip_features.size(0)
        # Prepare the first input token (BOS) for every beam
        input_ids = torch.full((batch_size * beam_size, 1), self.tokenizer.bos_token_id,
                               dtype=torch.long, device=clip_features.device)
        # Start all but the first beam at -inf so the first step does not
        # select beam_size identical copies of the same hypothesis
        beam_scores = torch.zeros((batch_size, beam_size), device=clip_features.device)
        beam_scores[:, 1:] = -float('Inf')
        done = torch.zeros(batch_size, dtype=torch.bool, device=clip_features.device)

        for _ in range(self.max_seq_length):
            outputs = self._forward(clip_features.repeat_interleave(beam_size, 0), input_ids)
            next_token_logits = outputs.logits[:, -1, :]
            next_token_logprobs = F.log_softmax(next_token_logits, dim=-1)
            next_token_logprobs = next_token_logprobs.view(batch_size, beam_size, -1)

            # Add new token log-probabilities to the accumulated beam scores
            scores = beam_scores.unsqueeze(2) + next_token_logprobs
            scores = scores.view(batch_size, -1)

            # Get the top beam_size scores and their respective indices
            top_scores, top_indices = scores.topk(beam_size, dim=1)
            beam_scores = top_scores

            # Reshape input_ids to (batch, beam, seq_len)
            input_ids = input_ids.view(batch_size, beam_size, -1)

            # Recover which parent beam and which token each winner came from
            next_token_ids = top_indices % self.vocab_size
            beam_indices = top_indices // self.vocab_size
            next_input_ids = torch.cat(
                [input_ids.gather(1, beam_indices.unsqueeze(2).expand(-1, -1, input_ids.size(2))),
                 next_token_ids.unsqueeze(2)], dim=2)

            # Flatten back to (batch * beam, seq_len)
            input_ids = next_input_ids.view(batch_size * beam_size, -1)

            # A batch element is done once all of its beams emit EOS
            done = done | (next_token_ids == self.tokenizer.eos_token_id).all(dim=1)

            if done.all():
                break

        return input_ids.view(batch_size, beam_size, -1)
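
# Minimal usage sketch (assumes a concrete subclass that defines
# clip_project, decoder and _forward; the subclass and variable names
# below are illustrative, not part of this module):
#
#   model = SomeAttModelSubclass(args, tokenizer)
#   greedy_seqs = model._sample(clip_features, gpt_tokens)
#   beam_seqs = model._sample(clip_features, gpt_tokens,
#                             update_opts={'sample_method': 'beam_search'})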