Ruurd commited on
Commit
df366f4
·
1 Parent(s): 13b1370

Add insert noise: number of inserts

Browse files
Files changed (1) hide show
  1. app.py +6 -7
app.py CHANGED
@@ -56,14 +56,13 @@ def get_noising_schedule(i, max_it, sharpness=5.0):
56
  x = i / max_it
57
  return (np.exp(-sharpness * x) - np.exp(-sharpness)) / (1 - np.exp(-sharpness))
58
 
59
- def noisify_answer(input_ids, answer_start, threshold=1.0, eot_weight=1.0, insert_prob=0.05):
60
  noised = input_ids.copy()
61
  answer_len = len(noised) - answer_start
62
  num_to_noise = int(threshold * answer_len)
63
 
64
- # Randomly insert one token with probability `insert_prob`
65
- if rng.random() < insert_prob:
66
- insert_idx = rng.integers(answer_start + 1, len(noised)) # Avoid inserting at the very start
67
  insert_token = rng.choice(np.arange(vocab_size), p=token_probabilities)
68
  noised = np.concatenate([noised[:insert_idx], [insert_token], noised[insert_idx:]])
69
  noised = noised[:len(input_ids)]
@@ -138,7 +137,7 @@ def generate_diffusion_text(input_ids, answer_start):
138
  return sampled, conf
139
 
140
  # --- Inference Wrapper ---
141
- def diffusion_chat(question, eot_weight, max_it, sharpness, noise_clipping, use_confidence_noising, insert_prob):
142
  placeholder = "What do you know about the city of New York?"
143
  if question.strip() == "":
144
  question = placeholder
@@ -195,7 +194,7 @@ def diffusion_chat(question, eot_weight, max_it, sharpness, noise_clipping, use_
195
  if use_confidence_noising:
196
  current_tokens = confidence_guided_noising(generated_tokens, answer_start, confidences, threshold, eot_weight, noise_clipping)
197
  else:
198
- current_tokens = noisify_answer(generated_tokens, answer_start, threshold=threshold, eot_weight=eot_weight, insert_prob=insert_prob)
199
 
200
  time.sleep(0.01)
201
 
@@ -219,7 +218,7 @@ demo = gr.Interface(
219
  gr.Slider(1.0, 20.0, value=5.0, step=0.5, label="↓ = more noising (sharpness)"),
220
  gr.Slider(0.01, 1.0, value=0.05, step=0.01, label="↓ = more confidence guidance (noise clipping)"),
221
  gr.Checkbox(value=False, label="Use confidence-guided noising"),
222
- gr.Slider(0.0, 1.0, value=0.05, step=0.01, label="Chance to insert token (↓ = less structural change)")
223
  ],
224
  outputs=[gr.HTML(label="Diffusion Output")],
225
  title="Diffusion Language Model Chat",
 
56
  x = i / max_it
57
  return (np.exp(-sharpness * x) - np.exp(-sharpness)) / (1 - np.exp(-sharpness))
58
 
59
+ def noisify_answer(input_ids, answer_start, threshold=1.0, eot_weight=1.0, num_inserts=0):
60
  noised = input_ids.copy()
61
  answer_len = len(noised) - answer_start
62
  num_to_noise = int(threshold * answer_len)
63
 
64
+ for _ in range(num_inserts):
65
+ insert_idx = rng.integers(answer_start + 1, len(noised))
 
66
  insert_token = rng.choice(np.arange(vocab_size), p=token_probabilities)
67
  noised = np.concatenate([noised[:insert_idx], [insert_token], noised[insert_idx:]])
68
  noised = noised[:len(input_ids)]
 
137
  return sampled, conf
138
 
139
  # --- Inference Wrapper ---
140
+ def diffusion_chat(question, eot_weight, max_it, sharpness, noise_clipping, use_confidence_noising, num_inserts):
141
  placeholder = "What do you know about the city of New York?"
142
  if question.strip() == "":
143
  question = placeholder
 
194
  if use_confidence_noising:
195
  current_tokens = confidence_guided_noising(generated_tokens, answer_start, confidences, threshold, eot_weight, noise_clipping)
196
  else:
197
+ current_tokens = noisify_answer(generated_tokens, answer_start, threshold=threshold, eot_weight=eot_weight, num_inserts=num_inserts)
198
 
199
  time.sleep(0.01)
200
 
 
218
  gr.Slider(1.0, 20.0, value=5.0, step=0.5, label="↓ = more noising (sharpness)"),
219
  gr.Slider(0.01, 1.0, value=0.05, step=0.01, label="↓ = more confidence guidance (noise clipping)"),
220
  gr.Checkbox(value=False, label="Use confidence-guided noising"),
221
+ gr.Slider(0, 100, value=0, step=1, label="Number of tokens to insert randomly (↓ = less structural change)")
222
  ],
223
  outputs=[gr.HTML(label="Diffusion Output")],
224
  title="Diffusion Language Model Chat",