Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -18,7 +18,7 @@ CONFIG.set_default_api_key(api_key)
|
|
18 |
access_token = os.environ['HUGGING_FACE_HUB_TOKEN']
|
19 |
|
20 |
# Load the Language Model
|
21 |
-
llama = LanguageModel("meta-llama/Meta-Llama-3.1-8B"
|
22 |
|
23 |
#placeholder for reset
|
24 |
prompts_with_probs = pd.DataFrame(
|
@@ -55,7 +55,9 @@ def run_lens(model,PROMPT):
|
|
55 |
logits_lens_token_result_by_layer.append(logits_lens_next_token)
|
56 |
tokens_out = llama.lm_head.output.argmax(dim=-1).save()
|
57 |
expected_token = tokens_out[0][-1].save()
|
|
|
58 |
logits_lens_all_probs = np.concatenate([probs[:, expected_token].cpu().detach().to(torch.float32).numpy() for probs in logits_lens_probs_by_layer])
|
|
|
59 |
#get the rank of the expected token from each layer's distribution
|
60 |
for layer_probs in logits_lens_probs_by_layer:
|
61 |
# Sort the probabilities in descending order and find the rank of the expected token
|
@@ -113,7 +115,7 @@ def plot_prob(prompts_with_probs):
|
|
113 |
# Add labels and title
|
114 |
plt.xlabel('Layer Number')
|
115 |
plt.ylabel('Probability of Expected Token')
|
116 |
-
plt.title('Prob of expected token across layers\n(annotated with decoded output at each layer)')
|
117 |
plt.grid(True)
|
118 |
plt.ylim(0.0, 1.0)
|
119 |
plt.legend(title='Prompts', bbox_to_anchor=(0.5, -0.15), loc='upper center', ncol=1)
|
@@ -177,6 +179,8 @@ def plot_prob_mean(prompts_with_probs):
|
|
177 |
plt.title('Mean Probability of Expected Token')
|
178 |
plt.xticks(rotation=45, ha='right')
|
179 |
plt.grid(axis='y')
|
|
|
|
|
180 |
|
181 |
# Annotate the mean and variance on the bars
|
182 |
for bar, mean, var in zip(bars, summary_stats['mean_prob'], summary_stats['variance']):
|
@@ -277,18 +281,25 @@ def submit_prompts(prompts_data):
|
|
277 |
|
278 |
def clear_all(prompts):
|
279 |
prompts=[['']]
|
|
|
|
|
280 |
prompts_data = gr.Dataframe(headers=["Prompt"], row_count=5, col_count=1, value= prompts, type="array", interactive=True)
|
281 |
-
return prompts_data,plot_prob(prompts_with_probs),plot_rank(prompts_with_ranks),plot_prob_mean(prompts_with_probs),plot_rank_mean(prompts_with_ranks)
|
282 |
|
283 |
|
284 |
def gradio_interface():
|
285 |
with gr.Blocks(theme="gradio/monochrome") as demo:
|
286 |
prompts=[['']]
|
287 |
-
|
288 |
-
|
|
|
|
|
|
|
289 |
prompt_file.upload(process_file, inputs=[prompts_data,prompt_file], outputs=[prompts_data])
|
290 |
-
|
291 |
# Define the outputs
|
|
|
|
|
|
|
292 |
with gr.Row():
|
293 |
prob_visualization = gr.Image(value=plot_prob(prompts_with_probs), type="pil",label=" ")
|
294 |
rank_visualization = gr.Image(value=plot_rank(prompts_with_ranks), type="pil",label=" ")
|
@@ -296,14 +307,11 @@ def gradio_interface():
|
|
296 |
prob_mean_visualization = gr.Image(value=plot_prob_mean(prompts_with_probs), type="pil",label=" ")
|
297 |
rank_mean_visualization = gr.Image(value=plot_rank_mean(prompts_with_ranks), type="pil",label=" ")
|
298 |
|
299 |
-
|
300 |
-
|
301 |
-
|
302 |
-
submit_btn = gr.Button("Submit")
|
303 |
-
submit_btn.click(submit_prompts, inputs=[prompts_data], outputs=[prob_visualization,rank_visualization,prob_mean_visualization,rank_mean_visualization])#
|
304 |
|
305 |
|
306 |
demo.launch()
|
307 |
|
308 |
-
|
309 |
gradio_interface()
|
|
|
18 |
access_token = os.environ['HUGGING_FACE_HUB_TOKEN']
|
19 |
|
20 |
# Load the Language Model
|
21 |
+
llama = LanguageModel("meta-llama/Meta-Llama-3.1-8B")
|
22 |
|
23 |
#placeholder for reset
|
24 |
prompts_with_probs = pd.DataFrame(
|
|
|
55 |
logits_lens_token_result_by_layer.append(logits_lens_next_token)
|
56 |
tokens_out = llama.lm_head.output.argmax(dim=-1).save()
|
57 |
expected_token = tokens_out[0][-1].save()
|
58 |
+
# logits_lens_all_probs = np.concatenate([probs[:, expected_token].cpu().detach().numpy() for probs in logits_lens_probs_by_layer])
|
59 |
logits_lens_all_probs = np.concatenate([probs[:, expected_token].cpu().detach().to(torch.float32).numpy() for probs in logits_lens_probs_by_layer])
|
60 |
+
|
61 |
#get the rank of the expected token from each layer's distribution
|
62 |
for layer_probs in logits_lens_probs_by_layer:
|
63 |
# Sort the probabilities in descending order and find the rank of the expected token
|
|
|
115 |
# Add labels and title
|
116 |
plt.xlabel('Layer Number')
|
117 |
plt.ylabel('Probability of Expected Token')
|
118 |
+
plt.title('Prob of expected token across layers\n(annotated with actual decoded output at each layer)')
|
119 |
plt.grid(True)
|
120 |
plt.ylim(0.0, 1.0)
|
121 |
plt.legend(title='Prompts', bbox_to_anchor=(0.5, -0.15), loc='upper center', ncol=1)
|
|
|
179 |
plt.title('Mean Probability of Expected Token')
|
180 |
plt.xticks(rotation=45, ha='right')
|
181 |
plt.grid(axis='y')
|
182 |
+
plt.ylim(0, 1)
|
183 |
+
|
184 |
|
185 |
# Annotate the mean and variance on the bars
|
186 |
for bar, mean, var in zip(bars, summary_stats['mean_prob'], summary_stats['variance']):
|
|
|
281 |
|
282 |
def clear_all(prompts):
|
283 |
prompts=[['']]
|
284 |
+
# prompt_file=gr.File(type="filepath", label="Upload a File with Prompts")
|
285 |
+
prompt_file = None
|
286 |
prompts_data = gr.Dataframe(headers=["Prompt"], row_count=5, col_count=1, value= prompts, type="array", interactive=True)
|
287 |
+
return prompts_data,prompt_file,plot_prob(prompts_with_probs),plot_rank(prompts_with_ranks),plot_prob_mean(prompts_with_probs),plot_rank_mean(prompts_with_ranks)
|
288 |
|
289 |
|
290 |
def gradio_interface():
|
291 |
with gr.Blocks(theme="gradio/monochrome") as demo:
|
292 |
prompts=[['']]
|
293 |
+
with gr.Row():
|
294 |
+
with gr.Column(scale=3):
|
295 |
+
prompts_data = gr.Dataframe(headers=["Prompt"], row_count=5, col_count=1, value= prompts, type="array", interactive=True)
|
296 |
+
with gr.Column(scale=1):
|
297 |
+
prompt_file=gr.File(type="filepath", label="Upload a File with Prompts")
|
298 |
prompt_file.upload(process_file, inputs=[prompts_data,prompt_file], outputs=[prompts_data])
|
|
|
299 |
# Define the outputs
|
300 |
+
with gr.Row():
|
301 |
+
clear_btn = gr.Button("Clear")
|
302 |
+
submit_btn = gr.Button("Submit")
|
303 |
with gr.Row():
|
304 |
prob_visualization = gr.Image(value=plot_prob(prompts_with_probs), type="pil",label=" ")
|
305 |
rank_visualization = gr.Image(value=plot_rank(prompts_with_ranks), type="pil",label=" ")
|
|
|
307 |
prob_mean_visualization = gr.Image(value=plot_prob_mean(prompts_with_probs), type="pil",label=" ")
|
308 |
rank_mean_visualization = gr.Image(value=plot_rank_mean(prompts_with_ranks), type="pil",label=" ")
|
309 |
|
310 |
+
clear_btn.click(clear_all, inputs=[prompts_data], outputs=[prompts_data,prompt_file,prob_visualization,rank_visualization,prob_mean_visualization,rank_mean_visualization])
|
311 |
+
submit_btn.click(submit_prompts, inputs=[prompts_data], outputs=[prob_visualization,rank_visualization,prob_mean_visualization,rank_mean_visualization])#
|
312 |
+
prompt_file.clear(clear_all, inputs=[prompts_data], outputs=[prompts_data,prompt_file,prob_visualization,rank_visualization,prob_mean_visualization,rank_mean_visualization])
|
|
|
|
|
313 |
|
314 |
|
315 |
demo.launch()
|
316 |
|
|
|
317 |
gradio_interface()
|