# GAN_v1 / app.py
# Hem345 - commit 674763c ("Create app.py", verified)
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
import streamlit as st
class Generator(nn.Module):
    """Fully connected generator that maps latent noise to flattened images.

    The network is a stack of Linear layers with widths
    z_dim -> 128 -> 256 -> 512 -> img_dim. Hidden layers use ReLU and the
    last layer uses Tanh, so every output value lies in [-1, 1].

    Args:
        z_dim: Size of the latent noise vector (input features).
        img_dim: Number of output features (flattened image size).
    """

    def __init__(self, z_dim, img_dim):
        super().__init__()
        widths = [z_dim, 128, 256, 512]
        layers = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            layers += [nn.Linear(fan_in, fan_out), nn.ReLU()]
        layers += [nn.Linear(widths[-1], img_dim), nn.Tanh()]
        # A single nn.Sequential named `model` keeps parameter names
        # (model.0.weight, model.2.weight, ...) and the creation order
        # of the Linear layers unchanged.
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the network to `x`, whose last dimension must be z_dim."""
        return self.model(x)
class Discriminator(nn.Module):
    """Fully connected discriminator that scores flattened images.

    The network is a stack of Linear layers with widths
    img_dim -> 512 -> 256 -> 1. Hidden layers use LeakyReLU(0.2) and the
    output uses Sigmoid, so each input row gets one score in [0, 1].

    Args:
        img_dim: Number of input features (flattened image size).
    """

    def __init__(self, img_dim):
        super().__init__()
        widths = (img_dim, 512, 256)
        blocks = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            blocks.extend((nn.Linear(fan_in, fan_out), nn.LeakyReLU(0.2)))
        blocks.extend((nn.Linear(widths[-1], 1), nn.Sigmoid()))
        # Same layer order and `model` attribute name as before, so
        # state_dict keys (model.0.weight, ...) are unchanged.
        self.model = nn.Sequential(*blocks)

    def forward(self, x):
        """Score `x`, whose last dimension must be img_dim."""
        return self.model(x)
def train_gan(generator, discriminator, dataloader, n_epochs, z_dim, lr, log=None):
    """Train `generator` and `discriminator` adversarially with BCE loss.

    Each batch runs one discriminator step, then one generator step:

    - Discriminator step: real images are labelled 1 and generated images 0.
    - Generator step: generated images are labelled 1.

    Both networks use Adam with learning rate `lr`.

    Args:
        generator: Module mapping (batch, z_dim) noise to flattened images.
        discriminator: Module mapping flattened images to (batch, 1) scores
            in [0, 1], as required by BCELoss.
        dataloader: Iterable of (images, labels) pairs. Labels are ignored.
            Each image batch is flattened to (batch, -1).
        n_epochs: Number of passes over `dataloader`.
        z_dim: Size of the latent noise vectors.
        lr: Adam learning rate for both optimizers.
        log: Callable that receives one progress string per epoch.
            Defaults to `st.write`, so the Streamlit app shows progress.

    Returns:
        A list with one (mean_disc_loss, mean_gen_loss) tuple per epoch.

    Raises:
        ValueError: If `dataloader` yields no batches in an epoch.
    """
    if log is None:
        log = st.write
    # Create noise and labels on the generator's device, and move real
    # images there too, so training also works when the models are
    # not on the CPU.
    device = next(generator.parameters()).device
    loss_fn = nn.BCELoss()
    gen_optimizer = optim.Adam(generator.parameters(), lr=lr)
    disc_optimizer = optim.Adam(discriminator.parameters(), lr=lr)
    history = []
    for epoch in range(n_epochs):
        disc_total = 0.0
        gen_total = 0.0
        n_batches = 0
        for real_imgs, _ in dataloader:
            # Flatten each image, whatever its size (not only 28x28).
            real_imgs = real_imgs.to(device).flatten(start_dim=1)
            batch_size = real_imgs.size(0)

            # Train Discriminator. The fakes are detached, so this step
            # does not backpropagate into the generator.
            z = torch.randn(batch_size, z_dim, device=device)
            fake_imgs = generator(z)
            real_labels = torch.ones(batch_size, 1, device=device)
            fake_labels = torch.zeros(batch_size, 1, device=device)
            disc_real_loss = loss_fn(discriminator(real_imgs), real_labels)
            disc_fake_loss = loss_fn(discriminator(fake_imgs.detach()), fake_labels)
            disc_loss = disc_real_loss + disc_fake_loss
            disc_optimizer.zero_grad()
            disc_loss.backward()
            disc_optimizer.step()

            # Train Generator. Its loss is low when the just-updated
            # discriminator calls the fakes real. This backward pass also
            # writes discriminator grads; disc_optimizer.zero_grad()
            # clears them before the next discriminator step.
            gen_loss = loss_fn(discriminator(fake_imgs), real_labels)
            gen_optimizer.zero_grad()
            gen_loss.backward()
            gen_optimizer.step()

            disc_total += disc_loss.item()
            gen_total += gen_loss.item()
            n_batches += 1

        if n_batches == 0:
            raise ValueError("dataloader yielded no batches")
        # Report epoch means rather than the (noisy) last-batch values.
        disc_mean = disc_total / n_batches
        gen_mean = gen_total / n_batches
        history.append((disc_mean, gen_mean))
        log(f'Epoch [{epoch+1}/{n_epochs}], Discriminator Loss: {disc_mean}, Generator Loss: {gen_mean}')
    return history
def _build_mnist_loader(batch_size):
    """Return a shuffled DataLoader over the MNIST training split.

    Images are converted to tensors and normalized with mean 0.5 and
    std 0.5, which maps [0, 1] pixels to [-1, 1] (the generator's Tanh
    range). The dataset is fetched into ./data with download=True.
    """
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
    mnist_dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True)
    return DataLoader(mnist_dataset, batch_size=batch_size, shuffle=True)


def _show_samples(generator, z_dim, n_samples=10):
    """Render `n_samples` generator outputs as 28x28 grayscale images in Streamlit."""
    generator.eval()
    device = next(generator.parameters()).device
    # Sampling only: no_grad avoids building an autograd graph.
    with torch.no_grad():
        z = torch.randn(n_samples, z_dim, device=device)
        generated_imgs = generator(z).view(-1, 1, 28, 28).cpu()
    # squeeze=False always returns a 2-D array of Axes, even for one sample.
    fig, axes = plt.subplots(1, n_samples, figsize=(2 * n_samples, 2), squeeze=False)
    for img, ax in zip(generated_imgs, axes[0]):
        # Use Tanh's [-1, 1] as a fixed colour range so all samples share
        # one scale, instead of each being auto-scaled on its own.
        ax.imshow(img.squeeze().numpy(), cmap='gray', vmin=-1, vmax=1)
        ax.axis('off')
    st.pyplot(fig)
    # Close the figure so repeated runs do not accumulate open pyplot figures.
    plt.close(fig)


def main():
    """Streamlit entry point: train a GAN on MNIST, then show generated digits."""
    st.title("GAN Image Generator")
    z_dim = 100
    img_dim = 784  # 28 * 28 MNIST pixels, flattened
    # Create GAN models
    generator = Generator(z_dim, img_dim)
    discriminator = Discriminator(img_dim)
    # Training parameters
    lr = 0.0002
    batch_size = 32
    n_epochs = 10
    dataloader = _build_mnist_loader(batch_size)
    # NOTE(review): Streamlit re-executes this script on each interaction, so
    # the full training run repeats every time; consider caching the trained
    # generator (e.g. st.cache_resource) if the installed Streamlit has it.
    train_gan(generator, discriminator, dataloader, n_epochs, z_dim, lr)
    st.header("Generated Images")
    _show_samples(generator, z_dim)


if __name__ == "__main__":
    main()