# NOTE: stray export metadata preserved below (file size, commit hash,
# line-number index) — commented out so the module parses as Python.
# File size: 2,635 Bytes | 2c480a0 | lines 1-88
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel, PretrainedConfig
class BaseVAE(nn.Module):
    """Convolutional variational autoencoder for 32x32 images.

    The encoder is three stride-2 convolutions (32x32 -> 16x16 -> 8x8 -> 4x4,
    ending at 128 channels), flattened and projected to the mean and
    log-variance of a diagonal Gaussian. The decoder mirrors it with
    transposed convolutions and a final Sigmoid, so reconstructions lie
    in [0, 1].

    Args:
        latent_dim: Dimensionality of the latent vector z.
        in_channels: Number of input/output image channels. Defaults to 3
            (RGB), matching the original hard-coded behavior; exposed as a
            backward-compatible generalization for non-RGB data.
    """

    def __init__(self, latent_dim: int = 16, in_channels: int = 3):
        super().__init__()  # zero-arg super (modern Python 3 idiom)
        self.latent_dim = latent_dim
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, 32, 4, 2, 1),  # 32x32 -> 16x16
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, 2, 1),           # 16x16 -> 8x8
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 128, 4, 2, 1),          # 8x8 -> 4x4
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Flatten(),
        )
        self.fc_mu = nn.Linear(128 * 4 * 4, latent_dim)
        self.fc_logvar = nn.Linear(128 * 4 * 4, latent_dim)
        self.decoder_input = nn.Linear(latent_dim, 128 * 4 * 4)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 4, 2, 1),          # 4x4 -> 8x8
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 4, 2, 1),           # 8x8 -> 16x16
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.ConvTranspose2d(32, in_channels, 4, 2, 1),  # 16x16 -> 32x32
            nn.Sigmoid(),  # clamp reconstructions to [0, 1]
        )

    def encode(self, x: torch.Tensor) -> "tuple[torch.Tensor, torch.Tensor]":
        """Map images (B, C, 32, 32) to (mu, logvar), each of shape (B, latent_dim)."""
        h = self.encoder(x)
        return self.fc_mu(h), self.fc_logvar(h)

    def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        """Sample z ~ N(mu, sigma^2) via the reparameterization trick.

        Keeps sampling differentiable w.r.t. mu and logvar by drawing
        eps ~ N(0, I) and returning mu + eps * exp(0.5 * logvar).
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """Map latents (B, latent_dim) to reconstructions (B, C, 32, 32) in [0, 1]."""
        h = self.decoder_input(z)
        # Un-flatten back to the encoder's final feature-map shape.
        return self.decoder(h.view(-1, 128, 4, 4))

    def forward(self, x: torch.Tensor):
        """Full VAE pass. Returns (reconstruction, mu, logvar)."""
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar
class VAEConfig(PretrainedConfig):
    """Hugging Face config object carrying the single VAE hyperparameter.

    Args:
        latent_dim: Dimensionality of the VAE latent space (default 16).
    """

    model_type = "vae"

    def __init__(self, latent_dim: int = 16, **kwargs):
        # Let PretrainedConfig consume its own kwargs first, then record ours.
        super().__init__(**kwargs)
        self.latent_dim = latent_dim
class VAEModel(PreTrainedModel):
    """Thin PreTrainedModel wrapper around BaseVAE.

    Exposes the VAE through the Hugging Face model API so it can be saved
    to and loaded from the Hub via ``from_pretrained`` / ``save_pretrained``.
    """

    config_class = VAEConfig

    def __init__(self, config):
        super().__init__(config)
        self.vae = BaseVAE(latent_dim=config.latent_dim)
        self.post_init()  # standard HF weight-init / finalization hook

    def forward(self, x):
        # Delegate the full pass; returns (reconstruction, mu, logvar).
        recon, mu, logvar = self.vae(x)
        return recon, mu, logvar

    def encode(self, x):
        # Delegate to the wrapped VAE's encoder head.
        mu, logvar = self.vae.encode(x)
        return mu, logvar

    def decode(self, z):
        # Delegate to the wrapped VAE's decoder.
        return self.vae.decode(z)
# Pick the GPU when one is available, otherwise fall back to CPU, then load
# the pretrained emoji VAE from the Hub and switch it to inference mode.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = VAEModel.from_pretrained("BioMike/emoji-vae-init")
model = model.to(device)
model.eval()
|