# File size: 3,578 Bytes
# 674763c
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
import streamlit as st
# Define the Generator
class Generator(nn.Module):
    """MLP generator: maps a latent vector to a flattened image in [-1, 1]."""

    def __init__(self, z_dim, img_dim):
        super(Generator, self).__init__()
        # Widening fully-connected stack; final Tanh matches data normalized
        # to [-1, 1] by the training transform.
        widths = [z_dim, 128, 256, 512]
        layers = []
        for w_in, w_out in zip(widths, widths[1:]):
            layers.append(nn.Linear(w_in, w_out))
            layers.append(nn.ReLU())
        layers.append(nn.Linear(widths[-1], img_dim))
        layers.append(nn.Tanh())
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Generate a batch of flattened images from latent codes *x*."""
        return self.model(x)
# Define the Discriminator
class Discriminator(nn.Module):
    """MLP discriminator: scores flattened images as real-vs-fake probabilities."""

    def __init__(self, img_dim):
        super(Discriminator, self).__init__()
        # Narrowing fully-connected stack; Sigmoid yields a probability in [0, 1].
        dims = [img_dim, 512, 256]
        layers = []
        for d_in, d_out in zip(dims, dims[1:]):
            layers += [nn.Linear(d_in, d_out), nn.LeakyReLU(0.2)]
        layers += [nn.Linear(dims[-1], 1), nn.Sigmoid()]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the real-vs-fake score for each flattened image in *x*."""
        return self.model(x)
# Define training function
def train_gan(generator, discriminator, dataloader, n_epochs, z_dim, lr):
    """Adversarially train *generator* against *discriminator*.

    Alternates one discriminator update and one generator update per batch,
    using BCE loss and a separate Adam optimizer (learning rate *lr*) for each
    network. Per-epoch losses (from the last batch seen) are reported through
    Streamlit.
    """
    criterion = nn.BCELoss()
    opt_g = optim.Adam(generator.parameters(), lr=lr)
    opt_d = optim.Adam(discriminator.parameters(), lr=lr)
    for epoch in range(n_epochs):
        for images, _ in dataloader:
            images = images.view(-1, 784)  # flatten 28x28 images to vectors
            n = images.size(0)
            # --- discriminator step: real -> 1, fake -> 0 ---
            noise = torch.randn(n, z_dim)
            fakes = generator(noise)
            ones = torch.ones(n, 1)
            zeros = torch.zeros(n, 1)
            # Detach fakes so this step does not backprop into the generator.
            disc_loss = criterion(discriminator(images), ones) + criterion(discriminator(fakes.detach()), zeros)
            opt_d.zero_grad()
            disc_loss.backward()
            opt_d.step()
            # --- generator step: make the (just-updated) discriminator say "real" ---
            gen_loss = criterion(discriminator(fakes), ones)
            opt_g.zero_grad()
            gen_loss.backward()
            opt_g.step()
        st.write(f'Epoch [{epoch+1}/{n_epochs}], Discriminator Loss: {disc_loss.item()}, Generator Loss: {gen_loss.item()}')
# Main Streamlit function
def main():
    """Streamlit entry point: train a small GAN on MNIST and display samples.

    Builds the generator/discriminator pair, downloads MNIST, runs
    `train_gan`, then renders a row of 10 generated digits.
    """
    st.title("GAN Image Generator")
    z_dim = 100
    img_dim = 784  # 28 * 28 flattened MNIST image
    # Create GAN models
    generator = Generator(z_dim, img_dim)
    discriminator = Discriminator(img_dim)
    # Set training parameters
    lr = 0.0002
    batch_size = 32
    n_epochs = 10
    # Normalize pixels to [-1, 1] to match the generator's Tanh output range.
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
    # Load dataset
    mnist_dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True)
    dataloader = DataLoader(mnist_dataset, batch_size=batch_size, shuffle=True)
    # Train the GAN
    train_gan(generator, discriminator, dataloader, n_epochs, z_dim, lr)
    # Generate images after training
    st.header("Generated Images")
    # Fix: sample under no_grad so inference builds no autograd graph
    # (the original tracked gradients and relied on .detach() when plotting).
    with torch.no_grad():
        z = torch.randn(10, z_dim)
        generated_imgs = generator(z).view(-1, 1, 28, 28)
    fig, axes = plt.subplots(1, 10, figsize=(20, 2))
    for i, ax in enumerate(axes):
        ax.imshow(generated_imgs[i].squeeze().numpy(), cmap='gray')
        ax.axis('off')
    st.pyplot(fig)
    # Fix: close the figure so repeated Streamlit reruns don't leak
    # matplotlib figures.
    plt.close(fig)

if __name__ == "__main__":
    main()