jamino30 committed on
Commit
0275194
·
verified ·
1 Parent(s): 99fc825

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -22
app.py CHANGED
@@ -6,18 +6,20 @@ import gradio as gr
6
 
7
  # Define the VAE model
8
  class ConvVAE(nn.Module):
9
- def __init__(self, input_channels=3, latent_dim=16):
10
- super(ConvVAE, self).__init__()
11
  self.latent_dim = latent_dim
12
- self.enc_conv1 = nn.Conv2d(input_channels, 32, kernel_size=3, stride=2, padding=1)
13
- self.enc_conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)
14
- self.enc_conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
15
- self.fc_mu = nn.Linear(5120, latent_dim)
16
- self.fc_logvar = nn.Linear(5120, latent_dim)
17
- self.fc_decode = nn.Linear(latent_dim, 5120)
18
- self.dec_conv1 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=1, padding=1)
19
- self.dec_conv2 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1)
20
- self.dec_conv3 = nn.ConvTranspose2d(32, 3, kernel_size=3, stride=2, padding=1, output_padding=(0,1))
 
 
21
 
22
  def reparameterize(self, mu, logvar):
23
  std = torch.exp(0.5 * logvar)
@@ -32,31 +34,44 @@ class ConvVAE(nn.Module):
32
  mu = self.fc_mu(x)
33
  logvar = self.fc_logvar(x)
34
  z = self.reparameterize(mu, logvar)
35
- return self.decode(z)
 
36
 
37
  def decode(self, z):
38
  x = F.relu(self.fc_decode(z))
39
- x = x.view(x.size(0), 128, 4, 10)
40
  x = F.relu(self.dec_conv1(x))
41
  x = F.relu(self.dec_conv2(x))
42
  x = self.dec_conv3(x)
43
  return F.softmax(x, dim=1)
44
 
45
- # Load model
46
  model = ConvVAE()
47
  model.load_state_dict(torch.load("vae_supertux.pth", map_location=torch.device("cpu")))
48
  model.eval()
49
 
50
  def generate_map(seed: int = None):
51
- if seed:
52
- torch.manual_seed(seed)
53
- z = torch.randn(1, model.latent_dim)
 
 
54
  with torch.no_grad():
55
- output = model.decode(z) # Shape: (1, 3, 15, 40)
56
- output = output.squeeze(0).argmax(dim=0)
57
- grid = output.cpu().numpy()
58
- padded_grid = np.vstack([np.zeros((5, grid.shape[1]), dtype=int), grid]) # Append 5 rows of zeros
59
- return ["".join(map(str, row)) for row in padded_grid] # Convert each row to a string
 
 
 
 
 
 
 
 
 
 
60
 
61
  gr.Interface(
62
  fn=generate_map,
 
6
 
7
  # Define the VAE model
8
  class ConvVAE(nn.Module):
9
+ def __init__(self, input_channels=3, latent_dim=32):
10
+ super(ImprovedConvVAE, self).__init__()
11
  self.latent_dim = latent_dim
12
+ # Encoder
13
+ self.enc_conv1 = nn.Conv2d(input_channels, 64, kernel_size=3, stride=2, padding=1)
14
+ self.enc_conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)
15
+ self.enc_conv3 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
16
+ self.fc_mu = nn.Linear(256 * 4 * 10, latent_dim)
17
+ self.fc_logvar = nn.Linear(256 * 4 * 10, latent_dim)
18
+ # Decoder
19
+ self.fc_decode = nn.Linear(latent_dim, 256 * 4 * 10)
20
+ self.dec_conv1 = nn.ConvTranspose2d(256, 128, kernel_size=3, stride=1, padding=1)
21
+ self.dec_conv2 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1)
22
+ self.dec_conv3 = nn.ConvTranspose2d(64, 3, kernel_size=3, stride=2, padding=1, output_padding=(0,1))
23
 
24
  def reparameterize(self, mu, logvar):
25
  std = torch.exp(0.5 * logvar)
 
34
  mu = self.fc_mu(x)
35
  logvar = self.fc_logvar(x)
36
  z = self.reparameterize(mu, logvar)
37
+ out = self.decode(z)
38
+ return out, mu, logvar
39
 
40
  def decode(self, z):
41
  x = F.relu(self.fc_decode(z))
42
+ x = x.view(x.size(0), 256, 4, 10)
43
  x = F.relu(self.dec_conv1(x))
44
  x = F.relu(self.dec_conv2(x))
45
  x = self.dec_conv3(x)
46
  return F.softmax(x, dim=1)
47
 
48
+ # Load trained model
49
  model = ConvVAE()
50
  model.load_state_dict(torch.load("vae_supertux.pth", map_location=torch.device("cpu")))
51
  model.eval()
52
 
53
  def generate_map(seed: int = None):
54
+ model.eval()
55
+ if seed is None:
56
+ seed = torch.randint(10000, (1,)).item()
57
+ torch.manual_seed(seed)
58
+ z = torch.randn(1, model.latent_dim).to(device)
59
  with torch.no_grad():
60
+ output = model.decode(z)
61
+ output = sample_with_temperature(output, temperature=3)[0].cpu().numpy()
62
+ grid = np.pad(output, ((5, 0), (0, 0)), mode='constant', constant_values=0)
63
+
64
+ # Post-processing rule to collapse columns with inner air blocks
65
+ for j in range(len(grid[0])):
66
+ non_air_blocks = [grid[i, j] for i in range(len(grid)) if grid[i, j] != 0]
67
+ k = len(non_air_blocks)
68
+ if k > 0:
69
+ grid[20 - k:20, j] = non_air_blocks
70
+ grid[0:20 - k, j] = 0
71
+
72
+ return ["".join(map(str, row)) for row in grid] # Convert each row to a string
73
+
74
+
75
 
76
  gr.Interface(
77
  fn=generate_map,