File size: 3,578 Bytes
674763c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
import streamlit as st

# Define the Generator
class Generator(nn.Module):
    """Fully connected generator: latent vector -> flattened image.

    Three hidden ReLU layers (128, 256, 512 units) are followed by a linear
    projection to ``img_dim`` and a Tanh. Every output value therefore lies
    in [-1, 1].

    Args:
        z_dim: size of the last dimension of the latent input.
        img_dim: number of output values per sample (flattened pixels).
    """

    def __init__(self, z_dim, img_dim):
        super().__init__()
        widths = [z_dim, 128, 256, 512]
        layers = []
        # Hidden stack: Linear/ReLU pairs between consecutive widths.
        for fan_in, fan_out in zip(widths, widths[1:]):
            layers += [nn.Linear(fan_in, fan_out), nn.ReLU()]
        # Output head squashes into Tanh's [-1, 1] range.
        layers += [nn.Linear(widths[-1], img_dim), nn.Tanh()]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return generated images for latent batch ``x``."""
        return self.model(x)

# Define the Discriminator
class Discriminator(nn.Module):
    """Fully connected discriminator: flattened image -> probability of "real".

    Two hidden LeakyReLU(0.2) layers (512, 256 units) are followed by a single
    sigmoid unit. The output has shape ``(..., 1)`` with values in (0, 1),
    which is what ``nn.BCELoss`` expects.

    Args:
        img_dim: number of input values per sample (flattened pixels).
    """

    def __init__(self, img_dim):
        super().__init__()
        layers = []
        fan_in = img_dim
        for width in (512, 256):
            layers.extend((nn.Linear(fan_in, width), nn.LeakyReLU(0.2)))
            fan_in = width
        # Single-logit head turned into a probability.
        layers.extend((nn.Linear(fan_in, 1), nn.Sigmoid()))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return per-sample probabilities that ``x`` is a real image."""
        return self.model(x)

# Define training function
def train_gan(generator, discriminator, dataloader, n_epochs, z_dim, lr, log=None):
    """Train a GAN with the standard alternating BCE objective.

    For every batch, the discriminator is first updated to label real images
    1 and generator samples 0. The generator is then updated so that the
    discriminator labels its samples 1 (the non-saturating generator loss).
    After each epoch that saw at least one batch, one progress line with the
    last batch's losses is passed to ``log``.

    Args:
        generator: module mapping ``(batch, z_dim)`` noise to flattened images.
        discriminator: module mapping flattened images to probabilities of
            shape ``(batch, 1)``. ``nn.BCELoss`` needs probabilities, not logits.
        dataloader: iterable of ``(images, labels)`` batches. Labels are
            ignored, and each image is flattened to one row.
        n_epochs: number of passes over ``dataloader``.
        z_dim: size of the latent noise vector fed to ``generator``.
        lr: Adam learning rate shared by both optimizers.
        log: callable receiving one progress string per epoch. When omitted,
            ``st.write`` is used so the Streamlit app shows progress.
    """
    if log is None:
        # Looked up only when no logger is supplied, so callers passing
        # their own ``log`` never touch Streamlit.
        log = st.write

    # Create all tensors on the device holding the generator's weights, so
    # models moved to a GPU by the caller do not hit a device mismatch.
    device = next(generator.parameters()).device
    generator.train()
    discriminator.train()

    loss_fn = nn.BCELoss()
    gen_optimizer = optim.Adam(generator.parameters(), lr=lr)
    disc_optimizer = optim.Adam(discriminator.parameters(), lr=lr)

    for epoch in range(n_epochs):
        disc_loss = gen_loss = None
        for real_imgs, _ in dataloader:
            batch_size = real_imgs.size(0)
            # Flatten per sample. Deriving the width from the data (instead of
            # a hard-coded 784) keeps one row per image for any image size.
            real_imgs = real_imgs.reshape(batch_size, -1).to(device)

            # Train Discriminator
            z = torch.randn(batch_size, z_dim, device=device)
            fake_imgs = generator(z)

            real_labels = torch.ones(batch_size, 1, device=device)
            fake_labels = torch.zeros(batch_size, 1, device=device)

            disc_real_loss = loss_fn(discriminator(real_imgs), real_labels)
            # detach(): this step must not push gradients into the generator.
            disc_fake_loss = loss_fn(discriminator(fake_imgs.detach()), fake_labels)
            disc_loss = disc_real_loss + disc_fake_loss

            disc_optimizer.zero_grad()
            disc_loss.backward()
            disc_optimizer.step()

            # Train Generator: reward fakes the (just-updated) discriminator
            # calls real. This backward also fills discriminator grads; they
            # are cleared by disc_optimizer.zero_grad() before next use.
            output = discriminator(fake_imgs)
            gen_loss = loss_fn(output, real_labels)

            gen_optimizer.zero_grad()
            gen_loss.backward()
            gen_optimizer.step()

        if disc_loss is None:
            # Empty dataloader: no batch was trained, so there is no loss to report.
            continue
        log(f'Epoch [{epoch+1}/{n_epochs}], Discriminator Loss: {disc_loss.item()}, Generator Loss: {gen_loss.item()}')

# Main Streamlit function
def main():
    """Streamlit entry point: train a GAN on MNIST, then display 10 samples.

    NOTE(review): Streamlit re-executes the script on each page load or
    interaction, so this retrains for ``n_epochs`` every time. Consider
    caching the trained generator if that is too slow.
    """
    st.title("GAN Image Generator")

    z_dim = 100
    img_dim = 784  # 28 * 28 pixels, flattened

    # Create GAN models
    generator = Generator(z_dim, img_dim)
    discriminator = Discriminator(img_dim)

    # Set training parameters
    lr = 0.0002
    batch_size = 32
    n_epochs = 10

    # ToTensor yields [0, 1]; Normalize(0.5, 0.5) maps that to [-1, 1],
    # matching the generator's Tanh output range.
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])

    # Load dataset
    mnist_dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True)
    dataloader = DataLoader(mnist_dataset, batch_size=batch_size, shuffle=True)

    # Train the GAN
    train_gan(generator, discriminator, dataloader, n_epochs, z_dim, lr)

    # Generate images after training
    st.header("Generated Images")
    n_samples = 10
    generator.eval()
    with torch.no_grad():  # inference only: no autograd graph needed
        z = torch.randn(n_samples, z_dim)
        generated_imgs = generator(z).view(-1, 1, 28, 28).cpu().numpy()

    fig, axes = plt.subplots(1, n_samples, figsize=(20, 2))
    for img, ax in zip(generated_imgs, axes):
        # Fixed vmin/vmax show the generator's actual [-1, 1] output instead
        # of stretching each image's own min/max to full contrast.
        ax.imshow(img.squeeze(), cmap='gray', vmin=-1, vmax=1)
        ax.axis('off')
    st.pyplot(fig)
    # pyplot keeps every figure registered until closed; without this, each
    # Streamlit rerun leaks another figure.
    plt.close(fig)

if __name__ == "__main__":
    # Executed as a script (e.g. via `streamlit run`), not imported: start the app.
    main()