Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -6,18 +6,20 @@ import gradio as gr
|
|
6 |
|
7 |
# Define the VAE model
|
8 |
class ConvVAE(nn.Module):
|
9 |
-
def __init__(self, input_channels=3, latent_dim=
|
10 |
-
super(
|
11 |
self.latent_dim = latent_dim
|
12 |
-
|
13 |
-
self.
|
14 |
-
self.
|
15 |
-
self.
|
16 |
-
self.
|
17 |
-
self.
|
18 |
-
|
19 |
-
self.
|
20 |
-
self.
|
|
|
|
|
21 |
|
22 |
def reparameterize(self, mu, logvar):
|
23 |
std = torch.exp(0.5 * logvar)
|
@@ -32,31 +34,44 @@ class ConvVAE(nn.Module):
|
|
32 |
mu = self.fc_mu(x)
|
33 |
logvar = self.fc_logvar(x)
|
34 |
z = self.reparameterize(mu, logvar)
|
35 |
-
|
|
|
36 |
|
37 |
def decode(self, z):
|
38 |
x = F.relu(self.fc_decode(z))
|
39 |
-
x = x.view(x.size(0),
|
40 |
x = F.relu(self.dec_conv1(x))
|
41 |
x = F.relu(self.dec_conv2(x))
|
42 |
x = self.dec_conv3(x)
|
43 |
return F.softmax(x, dim=1)
|
44 |
|
45 |
-
# Load model
|
46 |
model = ConvVAE()
|
47 |
model.load_state_dict(torch.load("vae_supertux.pth", map_location=torch.device("cpu")))
|
48 |
model.eval()
|
49 |
|
50 |
def generate_map(seed: int = None):
|
51 |
-
|
52 |
-
|
53 |
-
|
|
|
|
|
54 |
with torch.no_grad():
|
55 |
-
output = model.decode(z)
|
56 |
-
output = output.
|
57 |
-
grid =
|
58 |
-
|
59 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
60 |
|
61 |
gr.Interface(
|
62 |
fn=generate_map,
|
|
|
6 |
|
7 |
# Define the VAE model
|
8 |
class ConvVAE(nn.Module):
|
9 |
+
def __init__(self, input_channels=3, latent_dim=32):
    """Build the convolutional VAE's encoder and decoder layers.

    Args:
        input_channels: Channels of the input map tensor (default 3).
        latent_dim: Size of the latent vector sampled by the VAE.
    """
    # BUG FIX: the original called super(ImprovedConvVAE, self).__init__(),
    # but this class is named ConvVAE, so instantiating it raised
    # NameError: name 'ImprovedConvVAE' is not defined.
    super().__init__()
    self.latent_dim = latent_dim
    # Encoder: two stride-2 convs downsample, one stride-1 conv deepens.
    # fc_mu/fc_logvar input size implies the encoder's final feature map
    # is (256, 4, 10) — decode() reshapes back to exactly this.
    self.enc_conv1 = nn.Conv2d(input_channels, 64, kernel_size=3, stride=2, padding=1)
    self.enc_conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)
    self.enc_conv3 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
    self.fc_mu = nn.Linear(256 * 4 * 10, latent_dim)
    self.fc_logvar = nn.Linear(256 * 4 * 10, latent_dim)
    # Decoder: mirrors the encoder with transposed convolutions.
    self.fc_decode = nn.Linear(latent_dim, 256 * 4 * 10)
    self.dec_conv1 = nn.ConvTranspose2d(256, 128, kernel_size=3, stride=1, padding=1)
    self.dec_conv2 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1)
    # NOTE(review): output channels are hardcoded to 3 rather than
    # input_channels — fine for the default, but confirm if generalizing.
    self.dec_conv3 = nn.ConvTranspose2d(64, 3, kernel_size=3, stride=2, padding=1, output_padding=(0, 1))
|
23 |
|
24 |
def reparameterize(self, mu, logvar):
|
25 |
std = torch.exp(0.5 * logvar)
|
|
|
34 |
mu = self.fc_mu(x)
|
35 |
logvar = self.fc_logvar(x)
|
36 |
z = self.reparameterize(mu, logvar)
|
37 |
+
out = self.decode(z)
|
38 |
+
return out, mu, logvar
|
39 |
|
40 |
def decode(self, z):
    """Decode latent vectors into per-tile class probability maps.

    Args:
        z: Latent batch of shape (batch, latent_dim).

    Returns:
        Tensor of per-pixel class probabilities; dim=1 (channels) sums to 1.
    """
    # Project the latent vector back to the flattened conv feature size.
    x = F.relu(self.fc_decode(z))
    # Reshape to (batch, 256, 4, 10) — the encoder's final feature-map shape.
    x = x.view(x.size(0), 256, 4, 10)
    # Upsample back toward the input resolution with transposed convs.
    x = F.relu(self.dec_conv1(x))
    x = F.relu(self.dec_conv2(x))
    x = self.dec_conv3(x)
    # Softmax over the channel dimension: a class distribution per pixel.
    return F.softmax(x, dim=1)
|
47 |
|
48 |
+
# Load trained model
# Instantiate the VAE with default hyperparameters and restore weights
# on CPU so the Space runs without a GPU.
model = ConvVAE()
# NOTE(review): torch.load on a pickle checkpoint can execute arbitrary
# code; consider weights_only=True if the installed torch supports it.
model.load_state_dict(torch.load("vae_supertux.pth", map_location=torch.device("cpu")))
# Switch to eval mode for inference.
model.eval()
|
52 |
|
53 |
def generate_map(seed: int | None = None) -> list[str]:
    """Sample a tile map from the trained VAE decoder.

    Args:
        seed: Optional RNG seed for reproducible output; when None a
            random seed in [0, 10000) is drawn.

    Returns:
        One string per grid row, each character a tile-class digit
        (0 represents air).
    """
    model.eval()  # redundant — already set at load time — but harmless
    if seed is None:
        seed = torch.randint(10000, (1,)).item()
    torch.manual_seed(seed)
    # NOTE(review): `device` is not defined anywhere visible in this
    # file — presumably a module-level torch.device; confirm.
    z = torch.randn(1, model.latent_dim).to(device)
    with torch.no_grad():
        output = model.decode(z)
        # sample_with_temperature is defined elsewhere — presumably turns
        # the per-class probabilities into a 2-D grid of tile indices;
        # verify against its definition.
        output = sample_with_temperature(output, temperature=3)[0].cpu().numpy()
    # Pad 5 empty (air) rows onto the top of the decoded grid.
    grid = np.pad(output, ((5, 0), (0, 0)), mode='constant', constant_values=0)

    # Post-processing rule to collapse columns with inner air blocks
    for j in range(len(grid[0])):
        # Gather the column's solid tiles, preserving top-to-bottom order.
        non_air_blocks = [grid[i, j] for i in range(len(grid)) if grid[i, j] != 0]
        k = len(non_air_blocks)
        if k > 0:
            # assumes the padded grid has exactly 20 rows — TODO confirm
            # (20 - k matches len(grid) - k only under that assumption).
            grid[20 - k:20, j] = non_air_blocks
            grid[0:20 - k, j] = 0

    return ["".join(map(str, row)) for row in grid]  # Convert each row to a string
|
73 |
+
|
74 |
+
|
75 |
|
76 |
gr.Interface(
|
77 |
fn=generate_map,
|