import torch
from typing import List, Tuple
from torch.nn import functional as F
from torch import distributed as tdist, nn as nn

from unitok import dist


def get_entropy_loss(latent_embed, codebook_embed, inv_entropy_tau):
    """Entropy regularizer over code assignments.

    Minimizing this loss sharpens each latent's assignment distribution
    (low per-sample entropy) while spreading usage evenly over the codebook
    (high entropy of the batch-averaged assignment distribution).

    Args:
        latent_embed: (N, D) latent vectors.
        codebook_embed: (vocab_size, D) codebook vectors.
        inv_entropy_tau: inverse softmax temperature.
    """
    # squared Euclidean distance ||z - e||^2 = ||z||^2 + ||e||^2 - 2 z.e, shape (N, vocab_size)
    E_dist = latent_embed.square().sum(dim=1, keepdim=True) + codebook_embed.square().sum(dim=1, keepdim=False)
    E_dist.addmm_(latent_embed, codebook_embed.T, alpha=-2, beta=1)
    logits = -E_dist.float().mul_(inv_entropy_tau)
    # per-sample entropy: mean entropy of each latent's assignment distribution
    prob, log_prob = logits.softmax(dim=-1), logits.log_softmax(dim=-1)  # both (N, vocab_size)
    per_sample_entropy = torch.mean((-prob * log_prob).sum(dim=-1))
    # codebook entropy: entropy of the batch-averaged assignment distribution
    avg_prob = prob.mean(dim=0)  # (vocab_size,)
    log_avg_prob = torch.log(avg_prob + 1e-7)
    codebook_entropy = (-avg_prob * log_avg_prob).sum()
    # low per-sample entropy + high codebook entropy => low loss
    entropy_loss = per_sample_entropy - codebook_entropy
    return entropy_loss
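

# Illustrative usage of get_entropy_loss (a sketch; shapes and the temperature
# below are assumptions, not values taken from this repo):
#   z  = F.normalize(torch.randn(8, 4), dim=1)     # 8 latents of width 4
#   cb = F.normalize(torch.randn(16, 4), dim=1)    # 16 codebook entries
#   loss = get_entropy_loss(z, cb, inv_entropy_tau=100.0)  # scalar tensor
# Minimizing it encourages sharp per-latent assignments and uniform codebook usage.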


class NormalizedEmbedding(nn.Embedding):
    def __init__(self, num_embeddings: int, embedding_dim: int):
        super().__init__(num_embeddings=num_embeddings, embedding_dim=embedding_dim)
        # self.norm_scale = nn.Parameter(torch.tensor(0.0, dtype=torch.float32))

    def forward(self, idx):
        return F.embedding(
            idx, F.normalize(self.weight, dim=1), self.padding_idx, self.max_norm,
            self.norm_type, self.scale_grad_by_freq, self.sparse
        )

    def get_norm_weight(self):
        return F.normalize(self.weight, dim=1)


class ResConv(nn.Conv2d):
    def __init__(self, embed_dim, quant_resi):
        ks = 3 if quant_resi < 0 else 1  # negative quant_resi selects a 3x3 conv, otherwise 1x1
        super().__init__(in_channels=embed_dim, out_channels=embed_dim, kernel_size=ks, stride=1, padding=ks // 2)
        self.resi_ratio = abs(quant_resi)  # |quant_resi| is the residual mixing ratio

    def forward(self, h_BChw):
        return h_BChw.mul(1 - self.resi_ratio) + super().forward(h_BChw).mul_(self.resi_ratio)


class VectorQuantizer(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        vocab_width: int,
        beta: float = 0.25,
        use_entropy_loss=False,
        entropy_temp=0.01,
    ):
        super().__init__()
        self.beta = beta
        self.vocab_size = vocab_size
        self.vocab_width = vocab_width
        self.vocab_usage_record_times: int = 0
        self.register_buffer('vocab_usage', torch.zeros(self.vocab_size))
        self.codebook = NormalizedEmbedding(self.vocab_size, self.vocab_width)

        self.use_entropy_loss = use_entropy_loss
        self.inv_entropy_tau = 1 / entropy_temp

    def init_vocab(self, eini: float):
        if eini > 0:
            nn.init.trunc_normal_(self.codebook.weight.data, std=eini)
        elif eini < 0:
            base = self.vocab_width ** -0.5
            base /= 36
            self.codebook.weight.data.uniform_(-abs(eini) * base, abs(eini) * base)

    def extra_repr(self) -> str:
        return f'beta={self.beta:g}'

    def forward(self, features):
        B, L, C = features.shape
        features = features.reshape(-1, C)
        # both latents and codebook are L2-normalized, so the argmax of the dot
        # product is a cosine-similarity (nearest-neighbor) lookup
        features = F.normalize(features, dim=-1).float()
        codebook_embed = self.codebook.get_norm_weight()
        indices = torch.argmax(features.detach() @ codebook_embed.T, dim=1)
        entropy_loss = get_entropy_loss(features, codebook_embed, self.inv_entropy_tau) if self.use_entropy_loss else 0
        features_hat = self.codebook(indices)

        # VQ loss: beta-weighted commitment term + codebook term
        vq_loss = (
            F.mse_loss(features_hat.detach(), features).mul_(self.beta)
            + F.mse_loss(features_hat, features.detach())
        )
        # straight-through estimator: forward pass uses the quantized features,
        # gradients flow back through the continuous features
        features_hat = (features_hat.detach() - features.detach()).add_(features)

        # update vocab_usage
        prob_per_class_is_chosen = indices.bincount(minlength=self.vocab_size).float()
        if self.training and dist.initialized():
            handler = tdist.all_reduce(prob_per_class_is_chosen, async_op=True)
            handler.wait()
        prob_per_class_is_chosen /= prob_per_class_is_chosen.sum()
        vocab_usage = (prob_per_class_is_chosen > 0.01 / self.vocab_size).float().mean().mul_(100)
        if self.vocab_usage_record_times == 0:
            self.vocab_usage.copy_(prob_per_class_is_chosen)
        elif self.vocab_usage_record_times < 100:
            self.vocab_usage.mul_(0.9).add_(prob_per_class_is_chosen, alpha=0.1)
        else:
            self.vocab_usage.mul_(0.99).add_(prob_per_class_is_chosen, alpha=0.01)
        self.vocab_usage_record_times += 1

        return features_hat.view(B, L, C), vq_loss, entropy_loss, vocab_usage

    def f_to_idx(self, features):
        B, L, C = features.shape
        features = features.reshape(-1, C)
        features = F.normalize(features, dim=-1).float()
        codebook_embed = self.codebook.get_norm_weight().float()
        indices = torch.argmax(features.detach() @ codebook_embed.T, dim=1)
        return indices.view(B, L)


class VectorQuantizerM(nn.Module):
    def __init__(
        self,
        vocab_size,
        vocab_width,
        beta=0.25,
        use_entropy_loss=False,
        entropy_temp=0.01,
        num_codebooks=16
    ):
        super().__init__()
        self.num_codebooks = num_codebooks
        self.codebooks = nn.ModuleList()
        for _ in range(num_codebooks):
            codebook = VectorQuantizer(
                vocab_size=vocab_size // num_codebooks,
                vocab_width=vocab_width // num_codebooks,
                beta=beta,
                use_entropy_loss=use_entropy_loss,
                entropy_temp=entropy_temp,
            )
            self.codebooks.append(codebook)

    def init_vocab(self, eini: float):
        for codebook in self.codebooks:
            codebook.init_vocab(eini)

    def f_to_idx(self, features):
        indices = []
        chunk_size = features.shape[-1] // self.num_codebooks
        split_features = features.split(chunk_size, dim=-1)
        for i, codebook in enumerate(self.codebooks):
            indices.append(codebook.f_to_idx(split_features[i]))
        indices = torch.stack(indices, dim=1)
        return indices

    def idx_to_f(self, indices):
        assert indices.shape[1] == self.num_codebooks
        latent_features = []
        for i, codebook in enumerate(self.codebooks):
            sub_indices = indices[:, i].flatten(start_dim=1)
            latent_feature = codebook.codebook(sub_indices)
            latent_features.append(latent_feature)
        latent_features = torch.cat(latent_features, dim=-1)
        return latent_features

    def forward(self, features):
        latent_features = []
        global_vq_loss = 0.
        global_entropy_loss = 0.
        global_vocab_usage = 0.
        chunk_size = features.shape[-1] // self.num_codebooks
        split_features = features.split(chunk_size, dim=-1)
        for i, codebook in enumerate(self.codebooks):
            latent_feature, vq_loss, entropy_loss, vocab_usage = codebook(split_features[i])
            latent_features.append(latent_feature)
            global_vq_loss += vq_loss
            global_entropy_loss += entropy_loss
            global_vocab_usage += vocab_usage
        latent_features = torch.cat(latent_features, dim=-1)
        global_entropy_loss /= self.num_codebooks
        global_vq_loss /= self.num_codebooks
        global_vocab_usage /= self.num_codebooks
        return latent_features, global_vq_loss, global_entropy_loss, global_vocab_usage
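

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): a minimal, hypothetical way
# to drive VectorQuantizerM with random features. The shapes and
# hyper-parameters below (batch of 2, 256 tokens, a 4096-entry vocabulary
# split over 16 sub-codebooks, 256-dim latents) are illustrative assumptions.
if __name__ == '__main__':
    quantizer = VectorQuantizerM(
        vocab_size=4096,
        vocab_width=256,
        beta=0.25,
        use_entropy_loss=True,
        entropy_temp=0.01,
        num_codebooks=16,
    )
    quantizer.init_vocab(eini=-1.0)
    quantizer.eval()  # keeps the sketch free of the distributed all_reduce branch

    feats = torch.randn(2, 256, 256)  # (B, L, C)
    quantized, vq_loss, entropy_loss, vocab_usage = quantizer(feats)
    print(quantized.shape, vq_loss.item(), float(entropy_loss), vocab_usage.item())

    # round trip through discrete indices
    idx = quantizer.f_to_idx(feats)  # (B, num_codebooks, L)
    recon = quantizer.idx_to_f(idx)  # (B, L, C)
    print(idx.shape, recon.shape)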