import torch
from typing import List, Tuple
from torch.nn import functional as F
from torch import distributed as tdist, nn as nn
from unitok import dist


def get_entropy_loss(latent_embed, codebook_embed, inv_entropy_tau):
    # squared L2 distance from every latent to every code; E_dist: (N, vocab_size)
    E_dist = latent_embed.square().sum(dim=1, keepdim=True) + codebook_embed.square().sum(dim=1, keepdim=False)
    E_dist.addmm_(latent_embed, codebook_embed.T, alpha=-2, beta=1)
    logits = -E_dist.float().mul_(inv_entropy_tau)
    # per-sample entropy of the soft assignment over the codebook
    prob, log_prob = logits.softmax(dim=-1), logits.log_softmax(dim=-1)  # both are (N, vocab_size)
    per_sample_entropy = torch.mean((-prob * log_prob).sum(dim=-1))
    # entropy of the batch-averaged assignment distribution
    avg_prob = prob.mean(dim=0)  # (vocab_size,)
    log_avg_prob = torch.log(avg_prob + 1e-7)
    codebook_entropy = (-avg_prob * log_avg_prob).sum()
    # final entropy loss: low per-sample entropy, high codebook entropy
    entropy_loss = per_sample_entropy - codebook_entropy
    return entropy_loss
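

# Note: minimizing `entropy_loss` pushes each latent toward a confident (low-entropy)
# assignment to a single code while keeping the batch-averaged assignment distribution
# spread out (high codebook entropy), which discourages codebook collapse.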


class NormalizedEmbedding(nn.Embedding):
    def __init__(self, num_embeddings: int, embedding_dim: int):
        super().__init__(num_embeddings=num_embeddings, embedding_dim=embedding_dim)
        # self.norm_scale = nn.Parameter(torch.tensor(0.0, dtype=torch.float32))

    def forward(self, idx):
        # look up embeddings from the L2-normalized weight matrix
        return F.embedding(
            idx, F.normalize(self.weight, dim=1), self.padding_idx, self.max_norm,
            self.norm_type, self.scale_grad_by_freq, self.sparse
        )

    def get_norm_weight(self):
        return F.normalize(self.weight, dim=1)


class ResConv(nn.Conv2d):
    def __init__(self, embed_dim, quant_resi):
        ks = 3 if quant_resi < 0 else 1
        super().__init__(in_channels=embed_dim, out_channels=embed_dim, kernel_size=ks, stride=1, padding=ks // 2)
        self.resi_ratio = abs(quant_resi)

    def forward(self, h_BChw):
        # residual blend: (1 - ratio) * x + ratio * conv(x)
        return h_BChw.mul(1 - self.resi_ratio) + super().forward(h_BChw).mul_(self.resi_ratio)


class VectorQuantizer(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        vocab_width: int,
        beta: float = 0.25,
        use_entropy_loss=False,
        entropy_temp=0.01,
    ):
        super().__init__()
        self.beta = beta
        self.vocab_size = vocab_size
        self.vocab_width = vocab_width
        self.vocab_usage_record_times: int = 0
        self.register_buffer('vocab_usage', torch.zeros(self.vocab_size))
        self.codebook = NormalizedEmbedding(self.vocab_size, self.vocab_width)
        self.use_entropy_loss = use_entropy_loss
        self.inv_entropy_tau = 1 / entropy_temp

    def init_vocab(self, eini: float):
        if eini > 0:
            nn.init.trunc_normal_(self.codebook.weight.data, std=eini)
        elif eini < 0:
            base = self.vocab_width ** -0.5
            base /= 36
            self.codebook.weight.data.uniform_(-abs(eini) * base, abs(eini) * base)

    def extra_repr(self) -> str:
        return f'beta={self.beta:g}'

    def forward(self, features):
        B, L, C = features.shape
        features = features.reshape(-1, C)
        features = F.normalize(features, dim=-1).float()
        codebook_embed = self.codebook.get_norm_weight()
        # nearest-code lookup by cosine similarity (both sides are unit-normalized)
        indices = torch.argmax(features.detach() @ codebook_embed.T, dim=1)
        entropy_loss = get_entropy_loss(features, codebook_embed, self.inv_entropy_tau) if self.use_entropy_loss else 0
        features_hat = self.codebook(indices)

        # commitment loss (weighted by beta) + codebook loss
        vq_loss = F.mse_loss(features_hat.detach(), features).mul_(self.beta) + F.mse_loss(features_hat, features.detach())
        # straight-through estimator: forward uses the quantized features, gradients flow to `features`
        features_hat = (features_hat.detach() - features.detach()).add_(features)

        # update vocab_usage
        prob_per_class_is_chosen = indices.bincount(minlength=self.vocab_size).float()
        handler = tdist.all_reduce(prob_per_class_is_chosen, async_op=True) if (
            self.training and dist.initialized()) else None
        if handler is not None:
            handler.wait()
        prob_per_class_is_chosen /= prob_per_class_is_chosen.sum()
        vocab_usage = (prob_per_class_is_chosen > 0.01 / self.vocab_size).float().mean().mul_(100)

        # exponential moving average of the per-code usage distribution
        if self.vocab_usage_record_times == 0:
            self.vocab_usage.copy_(prob_per_class_is_chosen)
        elif self.vocab_usage_record_times < 100:
            self.vocab_usage.mul_(0.9).add_(prob_per_class_is_chosen, alpha=0.1)
        else:
            self.vocab_usage.mul_(0.99).add_(prob_per_class_is_chosen, alpha=0.01)
        self.vocab_usage_record_times += 1

        return features_hat.view(B, L, C), vq_loss, entropy_loss, vocab_usage

    def f_to_idx(self, features):
        B, L, C = features.shape
        features = features.reshape(-1, C)
        features = F.normalize(features, dim=-1).float()
        codebook_embed = self.codebook.get_norm_weight().float()
        indices = torch.argmax(features.detach() @ codebook_embed.T, dim=1)
        return indices.view(B, L)
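

# Illustrative usage of a single VectorQuantizer (not part of the original file;
# the sizes below are assumptions):
#
#     vq = VectorQuantizer(vocab_size=512, vocab_width=8)
#     vq.init_vocab(eini=-1)
#     f_hat, vq_loss, entropy_loss, usage = vq(torch.randn(2, 16, 8))  # features: (B, L, C=vocab_width)
#     idx = vq.f_to_idx(torch.randn(2, 16, 8))                         # (B, L) codebook indices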


class VectorQuantizerM(nn.Module):
    def __init__(
        self,
        vocab_size,
        vocab_width,
        beta=0.25,
        use_entropy_loss=False,
        entropy_temp=0.01,
        num_codebooks=16,
    ):
        super().__init__()
        self.num_codebooks = num_codebooks
        self.codebooks = nn.ModuleList()
        for _ in range(num_codebooks):
            codebook = VectorQuantizer(
                vocab_size=vocab_size // num_codebooks,
                vocab_width=vocab_width // num_codebooks,
                beta=beta,
                use_entropy_loss=use_entropy_loss,
                entropy_temp=entropy_temp,
            )
            self.codebooks.append(codebook)

    def init_vocab(self, eini: float):
        for codebook in self.codebooks:
            codebook.init_vocab(eini)

    def f_to_idx(self, features):
        indices = []
        chunk_size = features.shape[-1] // self.num_codebooks
        split_features = features.split(chunk_size, dim=-1)
        for i, codebook in enumerate(self.codebooks):
            indices.append(codebook.f_to_idx(split_features[i]))
        indices = torch.stack(indices, dim=1)  # (B, num_codebooks, L)
        return indices

    def idx_to_f(self, indices):
        assert indices.shape[1] == self.num_codebooks
        latent_features = []
        for i, codebook in enumerate(self.codebooks):
            sub_indices = indices[:, i].flatten(start_dim=1)
            latent_feature = codebook.codebook(sub_indices)
            latent_features.append(latent_feature)
        latent_features = torch.cat(latent_features, dim=-1)
        return latent_features

    def forward(self, features):
        latent_features = []
        global_vq_loss = 0.
        global_entropy_loss = 0.
        global_vocab_usage = 0.
        # split the feature width evenly across the sub-codebooks and quantize each chunk
        chunk_size = features.shape[-1] // self.num_codebooks
        split_features = features.split(chunk_size, dim=-1)
        for i, codebook in enumerate(self.codebooks):
            latent_feature, vq_loss, entropy_loss, vocab_usage = codebook(split_features[i])
            latent_features.append(latent_feature)
            global_vq_loss += vq_loss
            global_entropy_loss += entropy_loss
            global_vocab_usage += vocab_usage
        latent_features = torch.cat(latent_features, dim=-1)
        # average the losses and usage over the sub-codebooks
        global_entropy_loss /= self.num_codebooks
        global_vq_loss /= self.num_codebooks
        global_vocab_usage /= self.num_codebooks
        return latent_features, global_vq_loss, global_entropy_loss, global_vocab_usage
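

if __name__ == '__main__':
    # Illustrative smoke test (not part of the original file); every size below is an
    # assumption chosen only to exercise the forward / f_to_idx / idx_to_f paths.
    torch.manual_seed(0)
    quantizer = VectorQuantizerM(vocab_size=4096, vocab_width=64, use_entropy_loss=True, num_codebooks=8)
    quantizer.init_vocab(eini=-1)
    quantizer.eval()  # skip the distributed all_reduce branch in this standalone check
    feats = torch.randn(2, 16, 64)        # (B, L, C)
    f_hat, vq_loss, entropy_loss, vocab_usage = quantizer(feats)
    idx = quantizer.f_to_idx(feats)       # (B, num_codebooks, L)
    f_dequant = quantizer.idx_to_f(idx)   # (B, L, C), de-quantized features
    print(f_hat.shape, idx.shape, f_dequant.shape, float(vq_loss), float(entropy_loss))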