AryaWu committed on
Commit
725c23f
·
verified ·
1 Parent(s): e152b09

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -12
app.py CHANGED
@@ -18,7 +18,7 @@ CONFIG.set_default_api_key(api_key)
18
  access_token = os.environ['HUGGING_FACE_HUB_TOKEN']
19
 
20
  # Load the Language Model
21
- llama = LanguageModel("meta-llama/Meta-Llama-3.1-8B", token=access_token)
22
 
23
  #placeholder for reset
24
  prompts_with_probs = pd.DataFrame(
@@ -55,7 +55,9 @@ def run_lens(model,PROMPT):
55
  logits_lens_token_result_by_layer.append(logits_lens_next_token)
56
  tokens_out = llama.lm_head.output.argmax(dim=-1).save()
57
  expected_token = tokens_out[0][-1].save()
 
58
  logits_lens_all_probs = np.concatenate([probs[:, expected_token].cpu().detach().to(torch.float32).numpy() for probs in logits_lens_probs_by_layer])
 
59
  #get the rank of the expected token from each layer's distribution
60
  for layer_probs in logits_lens_probs_by_layer:
61
  # Sort the probabilities in descending order and find the rank of the expected token
@@ -113,7 +115,7 @@ def plot_prob(prompts_with_probs):
113
  # Add labels and title
114
  plt.xlabel('Layer Number')
115
  plt.ylabel('Probability of Expected Token')
116
- plt.title('Prob of expected token across layers\n(annotated with decoded output at each layer)')
117
  plt.grid(True)
118
  plt.ylim(0.0, 1.0)
119
  plt.legend(title='Prompts', bbox_to_anchor=(0.5, -0.15), loc='upper center', ncol=1)
@@ -177,6 +179,8 @@ def plot_prob_mean(prompts_with_probs):
177
  plt.title('Mean Probability of Expected Token')
178
  plt.xticks(rotation=45, ha='right')
179
  plt.grid(axis='y')
 
 
180
 
181
  # Annotate the mean and variance on the bars
182
  for bar, mean, var in zip(bars, summary_stats['mean_prob'], summary_stats['variance']):
@@ -277,18 +281,25 @@ def submit_prompts(prompts_data):
277
 
278
  def clear_all(prompts):
279
  prompts=[['']]
 
 
280
  prompts_data = gr.Dataframe(headers=["Prompt"], row_count=5, col_count=1, value= prompts, type="array", interactive=True)
281
- return prompts_data,plot_prob(prompts_with_probs),plot_rank(prompts_with_ranks),plot_prob_mean(prompts_with_probs),plot_rank_mean(prompts_with_ranks)
282
 
283
 
284
  def gradio_interface():
285
  with gr.Blocks(theme="gradio/monochrome") as demo:
286
  prompts=[['']]
287
- prompts_data = gr.Dataframe(headers=["Prompt"], row_count=5, col_count=1, value= prompts, type="array", interactive=True)
288
- prompt_file=gr.File(type="filepath", label="Upload a File with Prompts")
 
 
 
289
  prompt_file.upload(process_file, inputs=[prompts_data,prompt_file], outputs=[prompts_data])
290
-
291
  # Define the outputs
 
 
 
292
  with gr.Row():
293
  prob_visualization = gr.Image(value=plot_prob(prompts_with_probs), type="pil",label=" ")
294
  rank_visualization = gr.Image(value=plot_rank(prompts_with_ranks), type="pil",label=" ")
@@ -296,14 +307,11 @@ def gradio_interface():
296
  prob_mean_visualization = gr.Image(value=plot_prob_mean(prompts_with_probs), type="pil",label=" ")
297
  rank_mean_visualization = gr.Image(value=plot_rank_mean(prompts_with_ranks), type="pil",label=" ")
298
 
299
- with gr.Row():
300
- clear_btn = gr.Button("Clear")
301
- clear_btn.click(clear_all, inputs=[prompts_data], outputs=[prompts_data,prob_visualization,rank_visualization,prob_mean_visualization,rank_mean_visualization])
302
- submit_btn = gr.Button("Submit")
303
- submit_btn.click(submit_prompts, inputs=[prompts_data], outputs=[prob_visualization,rank_visualization,prob_mean_visualization,rank_mean_visualization])#
304
 
305
 
306
  demo.launch()
307
 
308
-
309
  gradio_interface()
 
18
  access_token = os.environ['HUGGING_FACE_HUB_TOKEN']
19
 
20
  # Load the Language Model
21
+ llama = LanguageModel("meta-llama/Meta-Llama-3.1-8B")
22
 
23
  #placeholder for reset
24
  prompts_with_probs = pd.DataFrame(
 
55
  logits_lens_token_result_by_layer.append(logits_lens_next_token)
56
  tokens_out = llama.lm_head.output.argmax(dim=-1).save()
57
  expected_token = tokens_out[0][-1].save()
58
+ # logits_lens_all_probs = np.concatenate([probs[:, expected_token].cpu().detach().numpy() for probs in logits_lens_probs_by_layer])
59
  logits_lens_all_probs = np.concatenate([probs[:, expected_token].cpu().detach().to(torch.float32).numpy() for probs in logits_lens_probs_by_layer])
60
+
61
  #get the rank of the expected token from each layer's distribution
62
  for layer_probs in logits_lens_probs_by_layer:
63
  # Sort the probabilities in descending order and find the rank of the expected token
 
115
  # Add labels and title
116
  plt.xlabel('Layer Number')
117
  plt.ylabel('Probability of Expected Token')
118
+ plt.title('Prob of expected token across layers\n(annotated with actual decoded output at each layer)')
119
  plt.grid(True)
120
  plt.ylim(0.0, 1.0)
121
  plt.legend(title='Prompts', bbox_to_anchor=(0.5, -0.15), loc='upper center', ncol=1)
 
179
  plt.title('Mean Probability of Expected Token')
180
  plt.xticks(rotation=45, ha='right')
181
  plt.grid(axis='y')
182
+ plt.ylim(0, 1)
183
+
184
 
185
  # Annotate the mean and variance on the bars
186
  for bar, mean, var in zip(bars, summary_stats['mean_prob'], summary_stats['variance']):
 
281
 
282
  def clear_all(prompts):
283
  prompts=[['']]
284
+ # prompt_file=gr.File(type="filepath", label="Upload a File with Prompts")
285
+ prompt_file = None
286
  prompts_data = gr.Dataframe(headers=["Prompt"], row_count=5, col_count=1, value= prompts, type="array", interactive=True)
287
+ return prompts_data,prompt_file,plot_prob(prompts_with_probs),plot_rank(prompts_with_ranks),plot_prob_mean(prompts_with_probs),plot_rank_mean(prompts_with_ranks)
288
 
289
 
290
  def gradio_interface():
291
  with gr.Blocks(theme="gradio/monochrome") as demo:
292
  prompts=[['']]
293
+ with gr.Row():
294
+ with gr.Column(scale=3):
295
+ prompts_data = gr.Dataframe(headers=["Prompt"], row_count=5, col_count=1, value= prompts, type="array", interactive=True)
296
+ with gr.Column(scale=1):
297
+ prompt_file=gr.File(type="filepath", label="Upload a File with Prompts")
298
  prompt_file.upload(process_file, inputs=[prompts_data,prompt_file], outputs=[prompts_data])
 
299
  # Define the outputs
300
+ with gr.Row():
301
+ clear_btn = gr.Button("Clear")
302
+ submit_btn = gr.Button("Submit")
303
  with gr.Row():
304
  prob_visualization = gr.Image(value=plot_prob(prompts_with_probs), type="pil",label=" ")
305
  rank_visualization = gr.Image(value=plot_rank(prompts_with_ranks), type="pil",label=" ")
 
307
  prob_mean_visualization = gr.Image(value=plot_prob_mean(prompts_with_probs), type="pil",label=" ")
308
  rank_mean_visualization = gr.Image(value=plot_rank_mean(prompts_with_ranks), type="pil",label=" ")
309
 
310
+ clear_btn.click(clear_all, inputs=[prompts_data], outputs=[prompts_data,prompt_file,prob_visualization,rank_visualization,prob_mean_visualization,rank_mean_visualization])
311
+ submit_btn.click(submit_prompts, inputs=[prompts_data], outputs=[prob_visualization,rank_visualization,prob_mean_visualization,rank_mean_visualization])#
312
+ prompt_file.clear(clear_all, inputs=[prompts_data], outputs=[prompts_data,prompt_file,prob_visualization,rank_visualization,prob_mean_visualization,rank_mean_visualization])
 
 
313
 
314
 
315
  demo.launch()
316
 
 
317
  gradio_interface()