Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -50,6 +50,16 @@ model = ConvVAE()
|
|
50 |
model.load_state_dict(torch.load("vae_supertux.pth", map_location=torch.device("cpu")))
|
51 |
model.eval()
|
52 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
53 |
def generate_map(seed: int = None):
|
54 |
model.eval()
|
55 |
if seed is None:
|
|
|
50 |
model.load_state_dict(torch.load("vae_supertux.pth", map_location=torch.device("cpu")))
|
51 |
model.eval()
|
52 |
|
53 |
+
# Sampling
|
54 |
+
def sample_with_temperature(probs, temperature=1.2):
|
55 |
+
logits = torch.log(probs + 1e-8) / temperature
|
56 |
+
scaled_probs = torch.softmax(logits, dim=1)
|
57 |
+
batch, channels, height, width = scaled_probs.shape
|
58 |
+
scaled_probs = scaled_probs.permute(0, 2, 3, 1).contiguous().view(-1, channels)
|
59 |
+
sampled = torch.multinomial(scaled_probs, num_samples=1)
|
60 |
+
sampled = sampled.view(batch, height, width)
|
61 |
+
return sampled
|
62 |
+
|
63 |
def generate_map(seed: int = None):
|
64 |
model.eval()
|
65 |
if seed is None:
|