Spaces:
Running on Zero

Ruurd commited on
Commit
ea86b58
·
verified ·
1 Parent(s): 4cd194e

Remove unnecessary print statements - Add MASK noising

Browse files
Files changed (1) hide show
  1. app.py +1 -2
app.py CHANGED
@@ -115,6 +115,7 @@ def noisify_answer(input_ids, answer_start, threshold=1.0, eot_weight=1.0, clust
115
  noise = rng.choice(np.arange(vocab_size), size=len(noised_indices), p=mixed_probs)
116
  for idx, val in zip(noised_indices, noise):
117
  noised[idx] = val
 
118
 
119
  return noised, noised_indices
120
 
@@ -166,8 +167,6 @@ def generate_diffusion_text(input_ids):
166
  logits = logits.clamp(min=-1e4, max=1e4)
167
  probs = torch.nn.functional.softmax(logits, dim=-1)[0]
168
  probs = torch.clamp(probs, min=1e-8, max=1.0)
169
- print("probs", probs)
170
- print("probs min:", probs.min().item(), "max:", probs.max().item(), "sum:", probs.sum().item())
171
  assert torch.all(torch.isfinite(probs)), "Non-finite values in probs!"
172
  assert (probs >= 0).all(), "Negative probs!"
173
  sampled = torch.multinomial(probs, num_samples=1).squeeze(-1).tolist()
 
115
  noise = rng.choice(np.arange(vocab_size), size=len(noised_indices), p=mixed_probs)
116
  for idx, val in zip(noised_indices, noise):
117
  noised[idx] = val
118
+ noised[idx] = tokenizer.encode('MASK', add_special_tokens = False)
119
 
120
  return noised, noised_indices
121
 
 
167
  logits = logits.clamp(min=-1e4, max=1e4)
168
  probs = torch.nn.functional.softmax(logits, dim=-1)[0]
169
  probs = torch.clamp(probs, min=1e-8, max=1.0)
 
 
170
  assert torch.all(torch.isfinite(probs)), "Non-finite values in probs!"
171
  assert (probs >= 0).all(), "Negative probs!"
172
  sampled = torch.multinomial(probs, num_samples=1).squeeze(-1).tolist()