Spaces:
Running
on
Zero
Running
on
Zero
Remove unnecessary print statements - Add MASK noising
Browse files
app.py
CHANGED
@@ -115,6 +115,7 @@ def noisify_answer(input_ids, answer_start, threshold=1.0, eot_weight=1.0, clust
|
|
115 |
noise = rng.choice(np.arange(vocab_size), size=len(noised_indices), p=mixed_probs)
|
116 |
for idx, val in zip(noised_indices, noise):
|
117 |
noised[idx] = val
|
|
|
118 |
|
119 |
return noised, noised_indices
|
120 |
|
@@ -166,8 +167,6 @@ def generate_diffusion_text(input_ids):
|
|
166 |
logits = logits.clamp(min=-1e4, max=1e4)
|
167 |
probs = torch.nn.functional.softmax(logits, dim=-1)[0]
|
168 |
probs = torch.clamp(probs, min=1e-8, max=1.0)
|
169 |
-
print("probs", probs)
|
170 |
-
print("probs min:", probs.min().item(), "max:", probs.max().item(), "sum:", probs.sum().item())
|
171 |
assert torch.all(torch.isfinite(probs)), "Non-finite values in probs!"
|
172 |
assert (probs >= 0).all(), "Negative probs!"
|
173 |
sampled = torch.multinomial(probs, num_samples=1).squeeze(-1).tolist()
|
|
|
115 |
noise = rng.choice(np.arange(vocab_size), size=len(noised_indices), p=mixed_probs)
|
116 |
for idx, val in zip(noised_indices, noise):
|
117 |
noised[idx] = val
|
118 |
+
noised[idx] = tokenizer.encode('MASK', add_special_tokens = False)
|
119 |
|
120 |
return noised, noised_indices
|
121 |
|
|
|
167 |
logits = logits.clamp(min=-1e4, max=1e4)
|
168 |
probs = torch.nn.functional.softmax(logits, dim=-1)[0]
|
169 |
probs = torch.clamp(probs, min=1e-8, max=1.0)
|
|
|
|
|
170 |
assert torch.all(torch.isfinite(probs)), "Non-finite values in probs!"
|
171 |
assert (probs >= 0).all(), "Negative probs!"
|
172 |
sampled = torch.multinomial(probs, num_samples=1).squeeze(-1).tolist()
|