import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
import streamlit as st


# Define the Generator
class Generator(nn.Module):
    """Fully connected generator mapping a latent vector to a flat image.

    Args:
        z_dim: Size of the latent noise vector fed to ``forward``.
        img_dim: Number of output features (pixels) per generated image.
    """

    def __init__(self, z_dim, img_dim):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(z_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, img_dim),
            # Tanh keeps outputs in [-1, 1], the same range the real images are
            # mapped to by Normalize((0.5,), (0.5,)) in main().
            nn.Tanh(),
        )

    def forward(self, x):
        """Map latent vectors of shape (batch, z_dim) to (batch, img_dim)."""
        return self.model(x)


# Define the Discriminator
class Discriminator(nn.Module):
    """Fully connected discriminator scoring flat images as real or fake.

    Args:
        img_dim: Number of input features (pixels) per image.
    """

    def __init__(self, img_dim):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(img_dim, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            # Sigmoid yields a probability in (0, 1), as nn.BCELoss expects.
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Map images of shape (batch, img_dim) to (batch, 1) probabilities."""
        return self.model(x)


# Define training function
def train_gan(generator, discriminator, dataloader, n_epochs, z_dim, lr):
    """Train ``generator`` and ``discriminator`` adversarially.

    Each batch performs one discriminator update (real images labelled 1,
    generated images labelled 0), followed by one generator update (generated
    images labelled 1, so the generator is pushed to fool the discriminator).
    The losses of the last batch of each epoch are reported via ``st.write``.

    Args:
        generator: Model mapping (batch, z_dim) noise to (batch, img_dim) images.
        discriminator: Model mapping (batch, img_dim) images to (batch, 1)
            probabilities.
        dataloader: Iterable of ``(images, labels)`` batches; labels are ignored.
        n_epochs: Number of passes over ``dataloader``.
        z_dim: Latent vector size expected by ``generator``.
        lr: Learning rate for both Adam optimizers.

    Raises:
        ValueError: If ``dataloader`` yields no batches during an epoch, since
            there would be no loss to report.
    """
    loss_fn = nn.BCELoss()
    gen_optimizer = optim.Adam(generator.parameters(), lr=lr)
    disc_optimizer = optim.Adam(discriminator.parameters(), lr=lr)
    # Create tensors on the generator's device, so models the caller has moved
    # to a GPU keep working (on CPU this is identical to the old behaviour).
    device = next(generator.parameters()).device

    for epoch in range(n_epochs):
        disc_loss = gen_loss = None
        for real_imgs, _ in dataloader:
            # Flatten each image to one row. Unlike view(-1, 784), this keeps
            # the batch dimension intact for any image size.
            real_imgs = real_imgs.to(device).flatten(start_dim=1)
            batch_size = real_imgs.size(0)

            # Train Discriminator
            z = torch.randn(batch_size, z_dim, device=device)
            fake_imgs = generator(z)
            real_labels = torch.ones(batch_size, 1, device=device)
            fake_labels = torch.zeros(batch_size, 1, device=device)

            disc_real_loss = loss_fn(discriminator(real_imgs), real_labels)
            # detach(): the discriminator step must not backpropagate into G.
            disc_fake_loss = loss_fn(discriminator(fake_imgs.detach()), fake_labels)
            disc_loss = disc_real_loss + disc_fake_loss

            disc_optimizer.zero_grad()
            disc_loss.backward()
            disc_optimizer.step()

            # Train Generator: score the same fakes with the updated D and
            # target the "real" label. The gradients this leaves on D's
            # parameters are cleared by disc_optimizer.zero_grad() next batch.
            output = discriminator(fake_imgs)
            gen_loss = loss_fn(output, real_labels)

            gen_optimizer.zero_grad()
            gen_loss.backward()
            gen_optimizer.step()

        if disc_loss is None:
            raise ValueError("dataloader yielded no batches")
        st.write(f'Epoch [{epoch+1}/{n_epochs}], Discriminator Loss: {disc_loss.item()}, Generator Loss: {gen_loss.item()}')


# Main Streamlit function
def main():
    """Streamlit entry point: train a GAN on MNIST, then display 10 samples.

    NOTE(review): Streamlit re-executes the script on each rerun (e.g. a page
    reload), so the GAN is retrained from scratch every time. Consider caching
    the trained generator if that becomes a problem.
    """
    st.title("GAN Image Generator")

    z_dim = 100
    img_shape = (1, 28, 28)  # MNIST: one channel, 28x28 pixels
    img_dim = img_shape[0] * img_shape[1] * img_shape[2]  # 784 flattened pixels

    # Create GAN models
    generator = Generator(z_dim, img_dim)
    discriminator = Discriminator(img_dim)

    # Set training parameters
    lr = 0.0002
    batch_size = 32
    n_epochs = 10

    # ToTensor scales pixels to [0, 1]; Normalize then maps them to [-1, 1],
    # matching the generator's Tanh output range.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ])

    # Load dataset
    mnist_dataset = torchvision.datasets.MNIST(
        root='./data', train=True, transform=transform, download=True
    )
    dataloader = DataLoader(mnist_dataset, batch_size=batch_size, shuffle=True)

    # Train the GAN
    train_gan(generator, discriminator, dataloader, n_epochs, z_dim, lr)

    # Generate images after training
    st.header("Generated Images")
    generator.eval()
    with torch.no_grad():  # inference only: no autograd graph needed
        z = torch.randn(10, z_dim)
        generated_imgs = generator(z).view(-1, *img_shape)

    fig, axes = plt.subplots(1, 10, figsize=(20, 2))
    for ax, img in zip(axes, generated_imgs):
        # Pin the colour scale to the Tanh range so all samples share one
        # grey scale instead of each being autoscaled separately.
        ax.imshow(img.squeeze().numpy(), cmap='gray', vmin=-1, vmax=1)
        ax.axis('off')
    st.pyplot(fig)
    # Close the figure once rendered; pyplot otherwise keeps every figure
    # open for the life of the process, across Streamlit reruns.
    plt.close(fig)


if __name__ == "__main__":
    main()