jamino30 commited on
Commit
fae0652
·
verified ·
1 Parent(s): 019e9ef

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -0
app.py CHANGED
@@ -50,6 +50,16 @@ model = ConvVAE()
50
  model.load_state_dict(torch.load("vae_supertux.pth", map_location=torch.device("cpu")))
51
  model.eval()
52
 
 
 
 
 
 
 
 
 
 
 
53
  def generate_map(seed: int = None):
54
  model.eval()
55
  if seed is None:
 
50
  model.load_state_dict(torch.load("vae_supertux.pth", map_location=torch.device("cpu")))
51
  model.eval()
52
 
53
+ # Sampling
54
+ def sample_with_temperature(probs, temperature=1.2):
55
+ logits = torch.log(probs + 1e-8) / temperature
56
+ scaled_probs = torch.softmax(logits, dim=1)
57
+ batch, channels, height, width = scaled_probs.shape
58
+ scaled_probs = scaled_probs.permute(0, 2, 3, 1).contiguous().view(-1, channels)
59
+ sampled = torch.multinomial(scaled_probs, num_samples=1)
60
+ sampled = sampled.view(batch, height, width)
61
+ return sampled
62
+
63
  def generate_map(seed: int = None):
64
  model.eval()
65
  if seed is None: