Ruurd committed
Commit 9756472 · 1 Parent(s): d29da35

Fix red highlighting

Files changed (1)
  1. app.py +15 -9
app.py CHANGED
@@ -166,7 +166,7 @@ def diffusion_chat(question, eot_weight, max_it, sharpness, noise_clipping, use_
         ori_input_tokens, answer_start, threshold=1.0, eot_weight=eot_weight, clustering=clustering
     )
     last_tokens = []
-    just_noised_indices = []
+    prev_decoded_tokens = []
 
     for i in range(max_it):
         print('Generating output')
@@ -176,24 +176,26 @@ def diffusion_chat(question, eot_weight, max_it, sharpness, noise_clipping, use_
         # --- Decode and highlight changed tokens in GREEN ---
         decoded_ids = current_tokens[answer_start:]
         decoded_tokens = tokenizer.convert_ids_to_tokens(decoded_ids)
-        filtered_tokens = [tok for tok in decoded_tokens if tokenizer.convert_tokens_to_ids(tok) != eot_token_id]
 
         highlighted = []
-        for i, tok in enumerate(decoded_tokens):
+        for j, tok in enumerate(decoded_tokens):
             token_str = tokenizer.convert_tokens_to_string([tok])
-            if filtered_tokens and i < len(filtered_tokens) and tok != filtered_tokens[i]:
+            if prev_decoded_tokens and j < len(prev_decoded_tokens) and tok != prev_decoded_tokens[j]:
                 highlighted.append(f'<span style="color:green">{token_str}</span>')
             else:
                 highlighted.append(token_str)
 
+        prev_decoded_tokens = decoded_tokens
         yield f"<b>Iteration {i+1}/{max_it} (after generation):</b><br>" + "".join(highlighted).replace('\n', '<br>')
         time.sleep(0.1)
 
         # --- Apply noising and highlight RED tokens ---
         threshold = get_noising_schedule(i, max_it, sharpness=sharpness)
         if use_confidence_noising:
-            current_tokens = confidence_guided_noising(generated_tokens, answer_start, confidences, threshold, eot_weight, noise_clipping)
-            just_noised_indices = []  # optional: could track confidence-weighted indices too
+            current_tokens = confidence_guided_noising(
+                generated_tokens, answer_start, confidences, threshold, eot_weight, noise_clipping
+            )
+            just_noised_indices = []  # Optional: could extract from confidence scores
         else:
             current_tokens, just_noised_indices = noisify_answer(
                 generated_tokens, answer_start, threshold=threshold, eot_weight=eot_weight, clustering=clustering
@@ -201,12 +203,15 @@ def diffusion_chat(question, eot_weight, max_it, sharpness, noise_clipping, use_
 
         decoded_ids = current_tokens[answer_start:]
         decoded_tokens = tokenizer.convert_ids_to_tokens(decoded_ids)
-        filtered_tokens = [tok for tok in decoded_tokens if tokenizer.convert_tokens_to_ids(tok) != eot_token_id]
 
         highlighted = []
-        for i, tok in enumerate(filtered_tokens):
+        for j, tok in enumerate(decoded_tokens):
+            tok_id = tokenizer.convert_tokens_to_ids(tok)
+            if tok_id == eot_token_id:
+                continue  # Skip EOT tokens in display
+
             token_str = tokenizer.convert_tokens_to_string([tok])
-            abs_idx = answer_start + i
+            abs_idx = answer_start + j
             if abs_idx in just_noised_indices:
                 highlighted.append(f'<span style="color:red">{token_str}</span>')
             else:
@@ -224,6 +229,7 @@ def diffusion_chat(question, eot_weight, max_it, sharpness, noise_clipping, use_
                 break
 
 
+
         final_tokens = tokenizer.convert_ids_to_tokens(current_tokens[answer_start:])
         final_tokens = [tok for tok in final_tokens if tokenizer.convert_tokens_to_ids(tok) != eot_token_id]
         final_output = tokenizer.convert_tokens_to_string(final_tokens)
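
The green pass now diffs each iteration's decoded tokens against the previous iteration's (prev_decoded_tokens) instead of against an EOT-filtered copy of the same list, which had also shadowed the outer loop variable i. A minimal standalone sketch of the new comparison, using toy token strings rather than the app's real tokenizer:

# Sketch of the green-highlight logic after this commit (toy tokens,
# illustrative only; the real code works on tokenizer token strings).
def highlight_changed(decoded_tokens, prev):
    out = []
    for j, tok in enumerate(decoded_tokens):
        # Green where the token at position j changed since the last iteration.
        if prev and j < len(prev) and tok != prev[j]:
            out.append(f'<span style="color:green">{tok}</span>')
        else:
            out.append(tok)
    return out

prev_decoded_tokens = []
step1 = ["The", "cat", "sat"]
step2 = ["The", "dog", "sat"]
print(highlight_changed(step1, prev_decoded_tokens))  # no previous step: nothing green
prev_decoded_tokens = step1
print(highlight_changed(step2, prev_decoded_tokens))  # only "dog" is green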
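The red-highlight fix itself is about index alignment: just_noised_indices holds absolute positions in the full token sequence, so enumerating an EOT-filtered list shifts every position after an EOT, and answer_start + i then points at the wrong token. The new code enumerates the unfiltered answer slice and skips EOT tokens inside the loop, keeping positions intact. A toy reproduction of the misalignment (the names EOT and tokens and the index values are illustrative, not from app.py):

# Demonstrates the off-by-one this commit fixes (toy data, illustrative only).
EOT = "<eot>"
answer_start = 2
tokens = ["ctx", "ctx", "The", EOT, "cat", "sat"]  # full sequence
just_noised_indices = {4}                          # absolute index of "cat"

# Buggy version: enumerating the EOT-filtered list shifts positions left
# after the EOT, so answer_start + i lands on the wrong token.
filtered = [t for t in tokens[answer_start:] if t != EOT]
buggy = [t for i, t in enumerate(filtered) if answer_start + i in just_noised_indices]

# Fixed version: enumerate the unfiltered slice and skip EOT in the loop,
# so position j is still consumed and answer_start + j stays absolute.
fixed = []
for j, t in enumerate(tokens[answer_start:]):
    if t == EOT:
        continue  # hidden from display, but its position still counts
    if answer_start + j in just_noised_indices:
        fixed.append(t)

print(buggy)  # ['sat'] -- wrong token marked as noised
print(fixed)  # ['cat'] -- matches the noised index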