Spaces:
Running
on
Zero
Running
on
Zero
Fix red highlighting
Browse files
app.py
CHANGED
@@ -165,37 +165,57 @@ def diffusion_chat(question, eot_weight, max_it, sharpness, noise_clipping, use_
|
|
165 |
current_tokens, just_noised_indices = noisify_answer(
|
166 |
ori_input_tokens, answer_start, threshold=1.0, eot_weight=eot_weight, clustering=clustering
|
167 |
)
|
168 |
-
prev_decoded_tokens = []
|
169 |
last_tokens = []
|
|
|
170 |
|
171 |
for i in range(max_it):
|
172 |
print('Generating output')
|
173 |
generated_tokens, confidences = generate_diffusion_text(current_tokens, answer_start)
|
174 |
current_tokens = generated_tokens
|
175 |
-
|
|
|
176 |
decoded_ids = current_tokens[answer_start:]
|
177 |
decoded_tokens = tokenizer.convert_ids_to_tokens(decoded_ids)
|
178 |
filtered_tokens = [tok for tok in decoded_tokens if tokenizer.convert_tokens_to_ids(tok) != eot_token_id]
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
|
188 |
-
|
189 |
-
|
190 |
-
|
191 |
-
|
192 |
-
|
|
|
|
|
|
|
193 |
else:
|
194 |
-
|
|
|
|
|
195 |
|
196 |
-
|
197 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
198 |
|
|
|
|
|
|
|
|
|
199 |
last_tokens.append(generated_tokens)
|
200 |
if len(last_tokens) > 3:
|
201 |
last_tokens.pop(0)
|
@@ -203,13 +223,6 @@ def diffusion_chat(question, eot_weight, max_it, sharpness, noise_clipping, use_
|
|
203 |
yield f"<b>Stopped early after {i+1} iterations.</b>"
|
204 |
break
|
205 |
|
206 |
-
threshold = get_noising_schedule(i, max_it, sharpness=sharpness)
|
207 |
-
if use_confidence_noising:
|
208 |
-
current_tokens = confidence_guided_noising(generated_tokens, answer_start, confidences, threshold, eot_weight, noise_clipping)
|
209 |
-
else:
|
210 |
-
current_tokens, just_noised_indices = noisify_answer(generated_tokens, answer_start, threshold=threshold, eot_weight=eot_weight, clustering=clustering)
|
211 |
-
|
212 |
-
time.sleep(0.01)
|
213 |
|
214 |
final_tokens = tokenizer.convert_ids_to_tokens(current_tokens[answer_start:])
|
215 |
final_tokens = [tok for tok in final_tokens if tokenizer.convert_tokens_to_ids(tok) != eot_token_id]
|
|
|
165 |
current_tokens, just_noised_indices = noisify_answer(
|
166 |
ori_input_tokens, answer_start, threshold=1.0, eot_weight=eot_weight, clustering=clustering
|
167 |
)
|
|
|
168 |
last_tokens = []
|
169 |
+
just_noised_indices = []
|
170 |
|
171 |
for i in range(max_it):
|
172 |
print('Generating output')
|
173 |
generated_tokens, confidences = generate_diffusion_text(current_tokens, answer_start)
|
174 |
current_tokens = generated_tokens
|
175 |
+
|
176 |
+
# --- Decode and highlight changed tokens in GREEN ---
|
177 |
decoded_ids = current_tokens[answer_start:]
|
178 |
decoded_tokens = tokenizer.convert_ids_to_tokens(decoded_ids)
|
179 |
filtered_tokens = [tok for tok in decoded_tokens if tokenizer.convert_tokens_to_ids(tok) != eot_token_id]
|
180 |
+
|
181 |
+
highlighted = []
|
182 |
+
for i, tok in enumerate(decoded_tokens):
|
183 |
+
token_str = tokenizer.convert_tokens_to_string([tok])
|
184 |
+
if filtered_tokens and i < len(filtered_tokens) and tok != filtered_tokens[i]:
|
185 |
+
highlighted.append(f'<span style="color:green">{token_str}</span>')
|
186 |
+
else:
|
187 |
+
highlighted.append(token_str)
|
188 |
+
|
189 |
+
yield f"<b>Iteration {i+1}/{max_it} (after generation):</b><br>" + "".join(highlighted).replace('\n', '<br>')
|
190 |
+
time.sleep(0.1)
|
191 |
+
|
192 |
+
# --- Apply noising and highlight RED tokens ---
|
193 |
+
threshold = get_noising_schedule(i, max_it, sharpness=sharpness)
|
194 |
+
if use_confidence_noising:
|
195 |
+
current_tokens = confidence_guided_noising(generated_tokens, answer_start, confidences, threshold, eot_weight, noise_clipping)
|
196 |
+
just_noised_indices = [] # optional: could track confidence-weighted indices too
|
197 |
else:
|
198 |
+
current_tokens, just_noised_indices = noisify_answer(
|
199 |
+
generated_tokens, answer_start, threshold=threshold, eot_weight=eot_weight, clustering=clustering
|
200 |
+
)
|
201 |
|
202 |
+
decoded_ids = current_tokens[answer_start:]
|
203 |
+
decoded_tokens = tokenizer.convert_ids_to_tokens(decoded_ids)
|
204 |
+
filtered_tokens = [tok for tok in decoded_tokens if tokenizer.convert_tokens_to_ids(tok) != eot_token_id]
|
205 |
+
|
206 |
+
highlighted = []
|
207 |
+
for i, tok in enumerate(filtered_tokens):
|
208 |
+
token_str = tokenizer.convert_tokens_to_string([tok])
|
209 |
+
abs_idx = answer_start + i
|
210 |
+
if abs_idx in just_noised_indices:
|
211 |
+
highlighted.append(f'<span style="color:red">{token_str}</span>')
|
212 |
+
else:
|
213 |
+
highlighted.append(token_str)
|
214 |
|
215 |
+
yield f"<b>Iteration {i+1}/{max_it} (after noising):</b><br>" + "".join(highlighted).replace('\n', '<br>')
|
216 |
+
time.sleep(0.1)
|
217 |
+
|
218 |
+
# --- Early stopping ---
|
219 |
last_tokens.append(generated_tokens)
|
220 |
if len(last_tokens) > 3:
|
221 |
last_tokens.pop(0)
|
|
|
223 |
yield f"<b>Stopped early after {i+1} iterations.</b>"
|
224 |
break
|
225 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
226 |
|
227 |
final_tokens = tokenizer.convert_ids_to_tokens(current_tokens[answer_start:])
|
228 |
final_tokens = [tok for tok in final_tokens if tokenizer.convert_tokens_to_ids(tok) != eot_token_id]
|