Guy24 committed on
Commit 95d29eb · 1 Parent(s): cbfd114

adding application

Files changed (1):
  1. app.py +46 -53
app.py CHANGED
@@ -226,58 +226,55 @@ def find_last_token_index(full_ids, word_ids):
 
 
 def analyse_word(model_name: str, extraction_template: str, word: str, patchscopes_template: str):
-    if PatchscopesRetriever is None:
-        return (
-            "<p style='color:red'>❌ Patchscopes library not found. Run:<br/>"
-            "<code>pip install git+https://github.com/schwartz-lab-NLP/Tokens2Words</code></p>"
-        )
-
-    model, tokenizer = get_model_and_tokenizer(model_name)
-
-    # Build extraction prompt (where hidden states will be collected)
-    extraction_prompt ="X"
-
-    # Identify last token position of the *word* inside the prompt IDs
-    word_token_ids = tokenizer.encode(word, add_special_tokens=False)
-
-    # Instantiate Patchscopes retriever
-    patch_retriever = PatchscopesRetriever(
-        model,
-        tokenizer,
-        extraction_prompt,
-        patchscopes_template,
-        prompt_target_placeholder="X",
-    )
-
-    # Run retrieval for the word across all layers (one pass)
-    retrieved_words = patch_retriever.get_hidden_states_and_retrieve_word(
-        word,
-        num_tokens_to_generate=len(tokenizer.tokenize(word)),
-    )[0]
-
-    # Build a table summarising which layers match
-    records = []
-    matches = 0
-    for layer_idx, ret_word in enumerate(retrieved_words):
-        match = ret_word.strip(" ") == word.strip(" ")
-        if match:
-            matches += 1
-        records.append({"Layer": layer_idx, "Retrieved": ret_word, "Match?": "✓" if match else ""})
-
-    df = pd.DataFrame(records)
-
-    def _style(row):
-        color = "background-color: lightgreen" if row["Match?"] else ""
-        return [color] * len(row)
-
-    html_table = df.style.apply(_style, axis=1).hide(axis="index").to_html(escape=False)
-
-    sub_tokens = tokenizer.convert_ids_to_tokens(word_token_ids)
-    top = (
-        f"<p><b>Sub‑word tokens:</b> {' , '.join(sub_tokens)}</p>"
-        f"<p><b>Total matched layers:</b> {matches} / {len(retrieved_words)}</p>"
-    )
-    return top + html_table
+    try:
+        model, tokenizer = get_model_and_tokenizer(model_name)
+
+        # Build extraction prompt (where hidden states will be collected)
+        extraction_prompt ="X"
+
+        # Identify last token position of the *word* inside the prompt IDs
+        word_token_ids = tokenizer.encode(word, add_special_tokens=False)
+
+        # Instantiate Patchscopes retriever
+        patch_retriever = PatchscopesRetriever(
+            model,
+            tokenizer,
+            extraction_prompt,
+            patchscopes_template,
+            prompt_target_placeholder="X",
+        )
+
+        # Run retrieval for the word across all layers (one pass)
+        retrieved_words = patch_retriever.get_hidden_states_and_retrieve_word(
+            word,
+            num_tokens_to_generate=len(tokenizer.tokenize(word)),
+        )[0]
+
+        # Build a table summarising which layers match
+        records = []
+        matches = 0
+        for layer_idx, ret_word in enumerate(retrieved_words):
+            match = ret_word.strip(" ") == word.strip(" ")
+            if match:
+                matches += 1
+            records.append({"Layer": layer_idx, "Retrieved": ret_word, "Match?": "✓" if match else ""})
+
+        df = pd.DataFrame(records)
+
+        def _style(row):
+            color = "background-color: lightgreen" if row["Match?"] else ""
+            return [color] * len(row)
+
+        html_table = df.style.apply(_style, axis=1).hide(axis="index").to_html(escape=False)
+
+        sub_tokens = tokenizer.convert_ids_to_tokens(word_token_ids)
+        top = (
+            f"<p><b>Sub‑word tokens:</b> {' , '.join(sub_tokens)}</p>"
+            f"<p><b>Total matched layers:</b> {matches} / {len(retrieved_words)}</p>"
+        )
+        return top + html_table
+    except Exception as e:
+        return f"<p style='color:red'>❌ Error: {str(e)}</p>"
 
 
 # ----------------------------- GRADIO UI -------------------------------
@@ -311,8 +308,4 @@ with gr.Blocks(theme="soft") as demo:
     )
 
 if __name__ == "__main__":
-    try:
-        demo.launch()
-    except Exception as e:
-        print(f"Error launching Gradio app: {e}")
-        raise
+    demo.launch()
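
Note: this commit drops the old PatchscopesRetriever import guard and the try/except around demo.launch(), and instead catches any runtime failure inside analyse_word itself, returning the exception text as an HTML string so it renders in the app rather than crashing the Space. Below is a minimal, hypothetical sketch of how such a function is typically wired into the Gradio Blocks UI; the component names are assumptions, since the actual UI section of app.py is unchanged and therefore not shown in this diff.

# Hypothetical wiring sketch (not part of this commit). Component names are
# assumptions; the real UI lives in the unchanged portion of app.py.
import gradio as gr

with gr.Blocks(theme="soft") as demo:
    model_name = gr.Textbox(label="Model name")
    extraction_template = gr.Textbox(label="Extraction template")
    word = gr.Textbox(label="Word to analyse")
    patchscopes_template = gr.Textbox(label="Patchscopes template")
    run_btn = gr.Button("Analyse")
    result = gr.HTML()

    # analyse_word now returns an HTML error string on failure instead of
    # raising, so any exception text appears directly in the result panel.
    run_btn.click(
        analyse_word,
        inputs=[model_name, extraction_template, word, patchscopes_template],
        outputs=result,
    )

if __name__ == "__main__":
    demo.launch()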