# Captured from a Hugging Face Spaces page (Space status at capture: "Sleeping").
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import numpy as np | |
import gradio as gr | |
# Define the VAE model
class ConvVAE(nn.Module):
    """Convolutional variational autoencoder over tile-class grids.

    The layer sizes are fixed for 15x40 grids: the encoder reduces
    15x40 -> 8x20 -> 4x10 and the decoder expands 4x10 -> 8x20 -> 15x40
    (the last layer's ``output_padding=(0, 1)`` restores the odd height and
    even width). The decoder ends in a softmax over the channel axis, so each
    output cell holds a probability distribution over the channels.

    Args:
        input_channels: Number of input channels. The decoder emits the same
            number, so the reconstruction has the same layout as the input.
            Defaults to 3.
        latent_dim: Size of the latent vector ``z``. Defaults to 16.
    """

    # Encoder output (channels, height, width) for a 15x40 input, and its
    # flattened size (128 * 4 * 10 == 5120) fed to the fully connected layers.
    _ENC_SHAPE = (128, 4, 10)
    _ENC_FEATURES = 128 * 4 * 10

    def __init__(self, input_channels=3, latent_dim=16):
        super().__init__()
        self.latent_dim = latent_dim
        # Encoder: two stride-2 downsamples followed by a stride-1 feature layer.
        self.enc_conv1 = nn.Conv2d(input_channels, 32, kernel_size=3, stride=2, padding=1)
        self.enc_conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)
        self.enc_conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.fc_mu = nn.Linear(self._ENC_FEATURES, latent_dim)
        self.fc_logvar = nn.Linear(self._ENC_FEATURES, latent_dim)
        # Decoder mirrors the encoder.
        self.fc_decode = nn.Linear(latent_dim, self._ENC_FEATURES)
        self.dec_conv1 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=1, padding=1)
        self.dec_conv2 = nn.ConvTranspose2d(
            64, 32, kernel_size=3, stride=2, padding=1, output_padding=1
        )
        # Emit one channel per input channel so the reconstruction matches the
        # input; a hard-coded 3 would break any other input_channels value.
        self.dec_conv3 = nn.ConvTranspose2d(
            32, input_channels, kernel_size=3, stride=2, padding=1, output_padding=(0, 1)
        )

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * exp(logvar / 2)`` with ``eps ~ N(0, I)``."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def encode(self, x):
        """Encode a batch ``x`` (N, input_channels, 15, 40) into ``(mu, logvar)``.

        Both returned tensors have shape (N, latent_dim).
        """
        x = F.relu(self.enc_conv1(x))
        x = F.relu(self.enc_conv2(x))
        x = F.relu(self.enc_conv3(x))
        x = x.view(x.size(0), -1)
        return self.fc_mu(x), self.fc_logvar(x)

    def forward(self, x):
        """Encode ``x``, sample a latent via reparameterization, and decode it.

        Returns only the decoded per-cell probabilities; call :meth:`encode`
        to obtain ``mu`` and ``logvar`` (e.g. for a KL term).
        """
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z)

    def decode(self, z):
        """Decode latents ``z`` (N, latent_dim) to (N, input_channels, 15, 40) probabilities."""
        x = F.relu(self.fc_decode(z))
        x = x.view(x.size(0), *self._ENC_SHAPE)
        x = F.relu(self.dec_conv1(x))
        x = F.relu(self.dec_conv2(x))
        x = self.dec_conv3(x)
        # Softmax over channels: each cell becomes a distribution over classes.
        return F.softmax(x, dim=1)
# Load model: build the network with its default sizes, restore the trained
# weights from "vae_supertux.pth" (a relative path, so it is resolved against
# the current working directory) mapped onto the CPU, then switch to eval mode.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints
# (torch >= 1.13 also accepts weights_only=True to restrict unpickling).
model = ConvVAE()
checkpoint_state = torch.load("vae_supertux.pth", map_location=torch.device("cpu"))
model.load_state_dict(checkpoint_state)
model.eval()
def generate_map(seed: int = None):
    """Sample one level from the VAE prior and return it as rows of digits.

    Args:
        seed: Optional RNG seed for the latent sample. A truthy seed makes the
            output reproducible; it is converted with ``int()`` because UI
            number widgets may deliver floats. Falsy values (``None`` and
            ``0``) leave the draw unseeded, so ``0`` acts as a "random"
            sentinel. NOTE(review): confirm that is intended — 0 is otherwise
            a valid torch seed.

    Returns:
        A list of strings, one per grid row: 5 rows of ``"0"`` padding on top,
        followed by the decoded rows (15x40 for the bundled model). Each
        character is the argmax class index of that cell.
    """
    if seed:
        # A private generator yields the same numbers torch.manual_seed(seed)
        # would, without reseeding the process-wide RNG that concurrent
        # requests and other code share.
        generator = torch.Generator().manual_seed(int(seed))
    else:
        generator = None  # fall back to torch's global RNG
    z = torch.randn(1, model.latent_dim, generator=generator)
    with torch.no_grad():
        output = model.decode(z)  # Shape: (1, channels, 15, 40)
    grid = output.squeeze(0).argmax(dim=0).cpu().numpy()
    # Prepend (not append) 5 rows of zeros above the decoded rows.
    padding = np.zeros((5, grid.shape[1]), dtype=grid.dtype)
    padded_grid = np.vstack([padding, grid])
    return ["".join(map(str, row)) for row in padded_grid]  # one string per row
# Serve generate_map as a web UI: one numeric seed in, the map rows out as JSON.
demo = gr.Interface(
    fn=generate_map,
    # precision=0 makes Gradio round the entry and pass it as an int, matching
    # generate_map's integer seed instead of handing it a float.
    inputs=gr.Number(label="Seed", precision=0),
    outputs=gr.JSON(label="Generated Map Grid"),
    title="VAE Level Generator",
    description="Returns a 20x40 grid as a list of strings where 0=air, 1=ground, 2=lava",
)
demo.launch()