|
import torch |
|
import torch.nn as nn |
|
import torch.optim as optim |
|
import torchvision |
|
import torchvision.transforms as transforms |
|
from torch.utils.data import DataLoader |
|
import matplotlib.pyplot as plt |
|
import numpy as np |
|
import streamlit as st |
|
|
|
|
|
class Generator(nn.Module):
    """Fully connected GAN generator mapping latent vectors to flat samples.

    Architecture: ``z_dim -> 128 -> 256 -> 512 -> img_dim`` with a ReLU after
    each hidden Linear layer and a final ``Tanh``, so every output element
    lies in ``[-1, 1]``.

    Args:
        z_dim: Size of the latent (noise) vector accepted by ``forward``.
        img_dim: Number of output features per sample.
    """

    def __init__(self, z_dim, img_dim):
        super().__init__()
        # Each consecutive pair of widths becomes Linear followed by ReLU.
        widths = (z_dim, 128, 256, 512)
        layers = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            layers.extend((nn.Linear(fan_in, fan_out), nn.ReLU()))
        # Output projection squashed into [-1, 1] by Tanh.
        layers.extend((nn.Linear(widths[-1], img_dim), nn.Tanh()))
        # A single ``model`` Sequential keeps state_dict keys as
        # ``model.0.weight``, ``model.2.weight``, ... for checkpoint compatibility.
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Map latent input of shape ``(..., z_dim)`` to output ``(..., img_dim)``."""
        return self.model(x)
|
|
|
|
|
class Discriminator(nn.Module):
    """Fully connected GAN discriminator scoring flat samples as real/fake.

    Architecture: ``img_dim -> 512 -> 256 -> 1`` with LeakyReLU (slope 0.2)
    after each hidden Linear layer and a final ``Sigmoid``, so the output is
    a probability in ``(0, 1)``.

    Args:
        img_dim: Number of input features per sample.
    """

    def __init__(self, img_dim):
        super().__init__()
        slope = 0.2
        blocks = []
        # Hidden layers: Linear followed by LeakyReLU.
        for fan_in, fan_out in ((img_dim, 512), (512, 256)):
            blocks += [nn.Linear(fan_in, fan_out), nn.LeakyReLU(slope)]
        # Single logit turned into a probability.
        blocks += [nn.Linear(256, 1), nn.Sigmoid()]
        # Same ``model.N`` indices as a hand-written Sequential, so
        # state_dict keys are unchanged.
        self.model = nn.Sequential(*blocks)

    def forward(self, x):
        """Map input of shape ``(..., img_dim)`` to probabilities ``(..., 1)``."""
        return self.model(x)
|
|
|
|
|
def _module_device(module):
    """Return the device of ``module``'s first parameter, or CPU if it has none."""
    param = next(module.parameters(), None)
    return param.device if param is not None else torch.device('cpu')


def train_gan(generator, discriminator, dataloader, n_epochs, z_dim, lr, log=None):
    """Train ``generator`` and ``discriminator`` adversarially with BCE loss.

    For every batch, one discriminator update is done first: real samples are
    labelled 1 and detached fake samples are labelled 0. One generator update
    follows, with fake samples labelled 1 so the generator is pushed to fool
    the discriminator. Both networks use Adam with the same learning rate.
    One progress line, built from the last batch's losses, is logged per epoch.

    Args:
        generator: Module mapping ``(batch, z_dim)`` noise to flat samples.
        discriminator: Module mapping flat samples to ``(batch, 1)`` outputs.
            These are scored with ``nn.BCELoss``, so they must lie in [0, 1].
        dataloader: Iterable of ``(images, labels)`` pairs. Labels are ignored.
            Each image tensor is flattened to ``(batch, -1)``, so its
            per-sample size must equal the discriminator's input width.
        n_epochs: Number of passes over ``dataloader``.
        z_dim: Size of the noise vector sampled for the generator.
        lr: Adam learning rate for both optimizers.
        log: Callable receiving one progress string per epoch. Defaults to
            ``st.write``, looked up at call time, so existing Streamlit callers
            behave as before.
    """
    if log is None:
        log = st.write

    loss_fn = nn.BCELoss()
    gen_optimizer = optim.Adam(generator.parameters(), lr=lr)
    disc_optimizer = optim.Adam(discriminator.parameters(), lr=lr)
    # Create noise/labels on the models' device instead of always on CPU.
    # Assumes generator and discriminator share a device.
    device = _module_device(generator)

    for epoch in range(n_epochs):
        disc_loss = gen_loss = None
        for real_imgs, _ in dataloader:
            # Flatten per sample rather than assuming 28x28 (784) inputs.
            # flatten() also handles non-contiguous tensors, unlike view().
            real_imgs = real_imgs.to(device).flatten(start_dim=1)
            batch_size = real_imgs.size(0)

            z = torch.randn(batch_size, z_dim, device=device)
            fake_imgs = generator(z)

            real_labels = torch.ones(batch_size, 1, device=device)
            fake_labels = torch.zeros(batch_size, 1, device=device)

            # Discriminator step. fake_imgs is detached so this loss does not
            # backpropagate into the generator.
            disc_real_loss = loss_fn(discriminator(real_imgs), real_labels)
            disc_fake_loss = loss_fn(discriminator(fake_imgs.detach()), fake_labels)
            disc_loss = disc_real_loss + disc_fake_loss

            disc_optimizer.zero_grad()
            disc_loss.backward()
            disc_optimizer.step()

            # Generator step: score fakes against "real" labels.
            # This backward also writes gradients into the discriminator's
            # parameters. They are cleared by disc_optimizer.zero_grad()
            # before the next discriminator update.
            output = discriminator(fake_imgs)
            gen_loss = loss_fn(output, real_labels)

            gen_optimizer.zero_grad()
            gen_loss.backward()
            gen_optimizer.step()

        if disc_loss is None:
            # The dataloader yielded no batches, so there is no loss to report.
            # The loss names used to be unbound here and raised an error.
            continue
        log(f'Epoch [{epoch+1}/{n_epochs}], Discriminator Loss: {disc_loss.item()}, Generator Loss: {gen_loss.item()}')
|
|
|
|
|
def main():
    """Streamlit entry point: train a GAN on MNIST, then display 10 samples.

    The MNIST training split is downloaded to ``./data`` if needed. A fresh
    Generator/Discriminator pair is trained for 10 epochs (batch size 32,
    lr 2e-4). Ten generated 28x28 images are then rendered in one row.
    """
    st.title("GAN Image Generator")

    z_dim = 100
    img_dim = 784  # flattened 28x28 MNIST image

    generator = Generator(z_dim, img_dim)
    discriminator = Discriminator(img_dim)

    lr = 0.0002
    batch_size = 32
    n_epochs = 10

    # ToTensor yields [0, 1]; Normalize(0.5, 0.5) maps that to [-1, 1],
    # which matches the range of the generator's Tanh output.
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])

    mnist_dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True)
    dataloader = DataLoader(mnist_dataset, batch_size=batch_size, shuffle=True)

    # NOTE(review): training restarts from scratch every time Streamlit reruns
    # this script. Consider caching the trained generator (e.g. with
    # st.cache_resource) if reruns are expected.
    train_gan(generator, discriminator, dataloader, n_epochs, z_dim, lr)

    st.header("Generated Images")
    n_samples = 10
    # Inference only: skip building an autograd graph.
    with torch.no_grad():
        z = torch.randn(n_samples, z_dim)
        generated_imgs = generator(z).view(-1, 1, 28, 28)

    fig, axes = plt.subplots(1, n_samples, figsize=(20, 2))
    for img, ax in zip(generated_imgs, axes):
        # A fixed [-1, 1] scale (the Tanh range) keeps brightness comparable
        # across samples instead of auto-scaling each image separately.
        ax.imshow(img.squeeze().cpu().numpy(), cmap='gray', vmin=-1, vmax=1)
        ax.axis('off')
    st.pyplot(fig)
    # pyplot keeps figures alive until they are closed. Close this one so
    # figures do not accumulate across reruns.
    plt.close(fig)
|
|
|
if __name__ == "__main__":
    # Launch the app only when this file is executed directly, not on import.
    main()
|
|