Spaces:
Running on Zero

Ruurd committed on
Commit
d29da35
·
1 Parent(s): 5976e7b

Fix red highlighting

Browse files
Files changed (1) hide show
  1. app.py +39 -26
app.py CHANGED
@@ -165,37 +165,57 @@ def diffusion_chat(question, eot_weight, max_it, sharpness, noise_clipping, use_
165
  current_tokens, just_noised_indices = noisify_answer(
166
  ori_input_tokens, answer_start, threshold=1.0, eot_weight=eot_weight, clustering=clustering
167
  )
168
- prev_decoded_tokens = []
169
  last_tokens = []
 
170
 
171
  for i in range(max_it):
172
  print('Generating output')
173
  generated_tokens, confidences = generate_diffusion_text(current_tokens, answer_start)
174
  current_tokens = generated_tokens
175
- just_noised_indices = []
 
176
  decoded_ids = current_tokens[answer_start:]
177
  decoded_tokens = tokenizer.convert_ids_to_tokens(decoded_ids)
178
  filtered_tokens = [tok for tok in decoded_tokens if tokenizer.convert_tokens_to_ids(tok) != eot_token_id]
179
- filtered_prev_tokens = [tok for tok in prev_decoded_tokens if tokenizer.convert_tokens_to_ids(tok) != eot_token_id] if prev_decoded_tokens else []
180
-
181
- if filtered_prev_tokens:
182
- highlighted = []
183
- for i, tok in enumerate(decoded_tokens):
184
- token_str = tokenizer.convert_tokens_to_string([tok])
185
-
186
- abs_idx = answer_start + i
187
- if abs_idx in just_noised_indices:
188
- highlighted.append(f'<span style="color:red">{token_str}</span>')
189
- elif prev_decoded_tokens and i < len(prev_decoded_tokens) and tok != prev_decoded_tokens[i]:
190
- highlighted.append(f'<span style="color:green">{token_str}</span>')
191
- else:
192
- highlighted.append(token_str)
 
 
 
193
  else:
194
- highlighted = [tokenizer.convert_tokens_to_string([tok]) for tok in filtered_tokens]
 
 
195
 
196
- prev_decoded_tokens = decoded_tokens
197
- yield f"<b>Iteration {i+1}/{max_it} (running):</b><br>" + "".join(highlighted).replace('\n', '<br>')
 
 
 
 
 
 
 
 
 
 
198
 
 
 
 
 
199
  last_tokens.append(generated_tokens)
200
  if len(last_tokens) > 3:
201
  last_tokens.pop(0)
@@ -203,13 +223,6 @@ def diffusion_chat(question, eot_weight, max_it, sharpness, noise_clipping, use_
203
  yield f"<b>Stopped early after {i+1} iterations.</b>"
204
  break
205
 
206
- threshold = get_noising_schedule(i, max_it, sharpness=sharpness)
207
- if use_confidence_noising:
208
- current_tokens = confidence_guided_noising(generated_tokens, answer_start, confidences, threshold, eot_weight, noise_clipping)
209
- else:
210
- current_tokens, just_noised_indices = noisify_answer(generated_tokens, answer_start, threshold=threshold, eot_weight=eot_weight, clustering=clustering)
211
-
212
- time.sleep(0.01)
213
 
214
  final_tokens = tokenizer.convert_ids_to_tokens(current_tokens[answer_start:])
215
  final_tokens = [tok for tok in final_tokens if tokenizer.convert_tokens_to_ids(tok) != eot_token_id]
 
165
  current_tokens, just_noised_indices = noisify_answer(
166
  ori_input_tokens, answer_start, threshold=1.0, eot_weight=eot_weight, clustering=clustering
167
  )
 
168
  last_tokens = []
169
+ just_noised_indices = []
170
 
171
  for i in range(max_it):
172
  print('Generating output')
173
  generated_tokens, confidences = generate_diffusion_text(current_tokens, answer_start)
174
  current_tokens = generated_tokens
175
+
176
+ # --- Decode and highlight changed tokens in GREEN ---
177
  decoded_ids = current_tokens[answer_start:]
178
  decoded_tokens = tokenizer.convert_ids_to_tokens(decoded_ids)
179
  filtered_tokens = [tok for tok in decoded_tokens if tokenizer.convert_tokens_to_ids(tok) != eot_token_id]
180
+
181
+ highlighted = []
182
+ for i, tok in enumerate(decoded_tokens):
183
+ token_str = tokenizer.convert_tokens_to_string([tok])
184
+ if filtered_tokens and i < len(filtered_tokens) and tok != filtered_tokens[i]:
185
+ highlighted.append(f'<span style="color:green">{token_str}</span>')
186
+ else:
187
+ highlighted.append(token_str)
188
+
189
+ yield f"<b>Iteration {i+1}/{max_it} (after generation):</b><br>" + "".join(highlighted).replace('\n', '<br>')
190
+ time.sleep(0.1)
191
+
192
+ # --- Apply noising and highlight RED tokens ---
193
+ threshold = get_noising_schedule(i, max_it, sharpness=sharpness)
194
+ if use_confidence_noising:
195
+ current_tokens = confidence_guided_noising(generated_tokens, answer_start, confidences, threshold, eot_weight, noise_clipping)
196
+ just_noised_indices = [] # optional: could track confidence-weighted indices too
197
  else:
198
+ current_tokens, just_noised_indices = noisify_answer(
199
+ generated_tokens, answer_start, threshold=threshold, eot_weight=eot_weight, clustering=clustering
200
+ )
201
 
202
+ decoded_ids = current_tokens[answer_start:]
203
+ decoded_tokens = tokenizer.convert_ids_to_tokens(decoded_ids)
204
+ filtered_tokens = [tok for tok in decoded_tokens if tokenizer.convert_tokens_to_ids(tok) != eot_token_id]
205
+
206
+ highlighted = []
207
+ for i, tok in enumerate(filtered_tokens):
208
+ token_str = tokenizer.convert_tokens_to_string([tok])
209
+ abs_idx = answer_start + i
210
+ if abs_idx in just_noised_indices:
211
+ highlighted.append(f'<span style="color:red">{token_str}</span>')
212
+ else:
213
+ highlighted.append(token_str)
214
 
215
+ yield f"<b>Iteration {i+1}/{max_it} (after noising):</b><br>" + "".join(highlighted).replace('\n', '<br>')
216
+ time.sleep(0.1)
217
+
218
+ # --- Early stopping ---
219
  last_tokens.append(generated_tokens)
220
  if len(last_tokens) > 3:
221
  last_tokens.pop(0)
 
223
  yield f"<b>Stopped early after {i+1} iterations.</b>"
224
  break
225
 
 
 
 
 
 
 
 
226
 
227
  final_tokens = tokenizer.convert_ids_to_tokens(current_tokens[answer_start:])
228
  final_tokens = [tok for tok in final_tokens if tokenizer.convert_tokens_to_ids(tok) != eot_token_id]