"""Discriminator losses, distance helpers, codebook statistics and weight init."""

import torch
import torch.nn.functional as F
from torch import nn


def adopt_weight(weight, global_step, threshold=0, value=0.):
    """Return ``value`` while ``global_step < threshold``, otherwise ``weight``.

    Args:
        weight: The weight to use once the threshold has been reached.
        global_step: Current step, compared against ``threshold``.
        threshold: Step below which ``value`` replaces ``weight``.
        value: Replacement returned before the threshold is reached.

    Returns:
        ``value`` if ``global_step < threshold`` else ``weight``.
    """
    if global_step < threshold:
        weight = value
    return weight


def hinge_d_loss(logits_real, logits_fake):
    """Compute the hinge discriminator loss.

    Args:
        logits_real: Discriminator logits for real samples.
        logits_fake: Discriminator logits for generated samples.

    Returns:
        ``0.5 * (mean(relu(1 - logits_real)) + mean(relu(1 + logits_fake)))``.
    """
    loss_real = torch.mean(F.relu(1. - logits_real))
    loss_fake = torch.mean(F.relu(1. + logits_fake))
    d_loss = 0.5 * (loss_real + loss_fake)
    return d_loss


def vanilla_d_loss(logits_real, logits_fake):
    """Compute the softplus-based (non-saturating) discriminator loss.

    Args:
        logits_real: Discriminator logits for real samples.
        logits_fake: Discriminator logits for generated samples.

    Returns:
        ``0.5 * (mean(softplus(-logits_real)) + mean(softplus(logits_fake)))``.
    """
    d_loss = 0.5 * (
        torch.mean(F.softplus(-logits_real)) +
        torch.mean(F.softplus(logits_fake)))
    return d_loss


def measure_perplexity(predicted_indices, n_embed):
    """Measure how evenly the ``n_embed`` codebook entries are being used.

    Args:
        predicted_indices: Integer index tensor accepted by ``F.one_hot``;
            it is flattened to one row per index.
        n_embed: Number of classes for the one-hot encoding.

    Returns:
        Tuple ``(perplexity, cluster_use)`` of 0-dim tensors: the exponential
        of the entropy of the average one-hot vector, and the number of
        entries whose average probability is greater than zero.
    """
    # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
    # eval cluster perplexity. when perplexity == num_embeddings then all
    # clusters are used exactly equally
    encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed)
    avg_probs = encodings.mean(0)
    # The 1e-10 keeps log() finite for entries that were never selected.
    perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()
    cluster_use = torch.sum(avg_probs > 0)
    return perplexity, cluster_use


def l1(x, y):
    """Return the element-wise absolute difference ``|x - y|``."""
    return torch.abs(x - y)


def l2(x, y):
    """Return the element-wise squared difference ``(x - y) ** 2``."""
    return torch.pow((x - y), 2)


def square_dist_loss(x, y):
    """Return the squared difference summed over dim 1, keeping that dim."""
    return torch.sum((x - y) ** 2, dim=1, keepdim=True)


def weights_init(m):
    """Initialise Conv / BatchNorm parameters in place, matched by class name.

    Modules whose class name contains ``'Conv'`` get weights drawn from
    N(0.0, 0.02). Modules whose class name contains ``'BatchNorm'`` get weights
    drawn from N(1.0, 0.02) and a zero bias. Other modules are left untouched.

    Matching is by class name, so a module may match without owning the
    parameter (for example ``BatchNorm2d(affine=False)``, or a container class
    whose name happens to include ``'Conv'``). Missing or ``None`` parameters
    are skipped instead of raising ``AttributeError``.

    Args:
        m: The module to initialise; suitable for ``nn.Module.apply``.
    """
    classname = m.__class__.__name__
    weight = getattr(m, 'weight', None)
    if 'Conv' in classname:
        if weight is not None:
            nn.init.normal_(weight.data, 0.0, 0.02)
    elif 'BatchNorm' in classname:
        if weight is not None:
            nn.init.normal_(weight.data, 1.0, 0.02)
        bias = getattr(m, 'bias', None)
        if bias is not None:
            nn.init.constant_(bias.data, 0)