import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel, PretrainedConfig


class BaseVAE(nn.Module):
    """Convolutional variational autoencoder for 3-channel 32x32 images.

    The encoder halves the spatial resolution three times (32 -> 16 -> 8 -> 4)
    while widening the channels (3 -> 32 -> 64 -> 128), then two linear heads
    map the flattened features to the mean and log-variance of the latent
    Gaussian. The decoder mirrors that path and ends in a sigmoid, so the
    reconstruction values lie in (0, 1).

    Args:
        latent_dim: Size of the latent vector.
    """

    # Shape of the bottleneck feature map shared by the encoder output and
    # the decoder input. Kept in one place so the linear layers and the
    # reshape in ``decode`` cannot drift apart.
    _FEATURE_CHANNELS = 128
    _FEATURE_SIZE = 4

    def __init__(self, latent_dim: int = 16) -> None:
        super().__init__()
        self.latent_dim = latent_dim
        flat_features = (
            self._FEATURE_CHANNELS * self._FEATURE_SIZE * self._FEATURE_SIZE
        )

        # NOTE: layer order inside each nn.Sequential determines the
        # state-dict keys (encoder.0, encoder.1, ...); do not reorder, or
        # previously saved checkpoints will no longer load.
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 32, 4, 2, 1),  # 32x32 -> 16x16
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, 2, 1),  # 16x16 -> 8x8
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, self._FEATURE_CHANNELS, 4, 2, 1),  # 8x8 -> 4x4
            nn.BatchNorm2d(self._FEATURE_CHANNELS),
            nn.ReLU(),
            nn.Flatten(),
        )
        self.fc_mu = nn.Linear(flat_features, latent_dim)
        self.fc_logvar = nn.Linear(flat_features, latent_dim)

        self.decoder_input = nn.Linear(latent_dim, flat_features)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(self._FEATURE_CHANNELS, 64, 4, 2, 1),  # 4x4 -> 8x8
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 4, 2, 1),  # 8x8 -> 16x16
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 3, 4, 2, 1),  # 16x16 -> 32x32
            nn.Sigmoid(),
        )

    def encode(self, x: torch.Tensor) -> tuple:
        """Encode a batch of images into latent Gaussian parameters.

        Args:
            x: Image batch; the layer sizes expect shape (N, 3, 32, 32).

        Returns:
            A ``(mu, logvar)`` tuple, each of shape (N, latent_dim).
        """
        x = self.encoder(x)
        mu = self.fc_mu(x)
        logvar = self.fc_logvar(x)
        return mu, logvar

    def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        """Draw a latent sample as ``mu + eps * std`` with ``eps ~ N(0, 1)``.

        Noise is always sampled, regardless of train/eval mode, so the
        result is stochastic unless the caller seeds the RNG.

        Args:
            mu: Mean of the latent Gaussian.
            logvar: Log-variance of the latent Gaussian (same shape as ``mu``).

        Returns:
            A sampled latent tensor with the same shape as ``mu``.
        """
        std = torch.exp(0.5 * logvar)  # sigma = exp(log(sigma^2) / 2)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """Decode latent vectors into images.

        Args:
            z: Latent batch of shape (N, latent_dim).

        Returns:
            Reconstructed images of shape (N, 3, 32, 32) with values in (0, 1).
        """
        x = self.decoder_input(z)
        # Undo the encoder's Flatten so the transposed convolutions receive
        # a (N, C, H, W) feature map.
        x = x.view(-1, self._FEATURE_CHANNELS, self._FEATURE_SIZE, self._FEATURE_SIZE)
        return self.decoder(x)

    def forward(self, x: torch.Tensor) -> tuple:
        """Run a full encode -> sample -> decode pass.

        Args:
            x: Image batch; the layer sizes expect shape (N, 3, 32, 32).

        Returns:
            A ``(recon, mu, logvar)`` tuple: the reconstruction and the
            latent Gaussian parameters it was sampled from.
        """
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        recon = self.decode(z)
        return recon, mu, logvar


class VAEConfig(PretrainedConfig):
    """Hugging Face configuration holding the VAE's latent dimensionality.

    Args:
        latent_dim: Size of the latent vector passed to ``BaseVAE``.
        **kwargs: Forwarded unchanged to ``PretrainedConfig``.
    """

    model_type = "vae"

    def __init__(self, latent_dim: int = 16, **kwargs) -> None:
        super().__init__(**kwargs)
        self.latent_dim = latent_dim


class VAEModel(PreTrainedModel):
    """``PreTrainedModel`` wrapper that exposes ``BaseVAE`` under ``self.vae``.

    All three public methods delegate directly to the wrapped ``BaseVAE``.
    """

    config_class = VAEConfig

    def __init__(self, config: VAEConfig) -> None:
        super().__init__(config)
        self.vae = BaseVAE(latent_dim=config.latent_dim)
        self.post_init()

    def forward(self, x: torch.Tensor) -> tuple:
        """Return ``(recon, mu, logvar)`` from the wrapped VAE."""
        return self.vae(x)

    def encode(self, x: torch.Tensor) -> tuple:
        """Return ``(mu, logvar)`` for the image batch ``x``."""
        return self.vae.encode(x)

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """Return images decoded from the latent batch ``z``."""
        return self.vae.decode(z)


# Prefer the GPU when one is available; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the published checkpoint and move its weights to the chosen device.
model = VAEModel.from_pretrained("BioMike/emoji-vae-init").to(device)
# Inference mode: BatchNorm layers use their running statistics.
model.eval()