Aranwer commited on
Commit
d2ec4c0
·
verified ·
1 Parent(s): 85a6306

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +46 -0
app.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import AutoTokenizer, AutoModel
3
+ import torch
4
+ import matplotlib.pyplot as plt
5
+ import seaborn as sns
6
+
7
+ def visualize_attention(model_name, sentence):
8
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
9
+ model = AutoModel.from_pretrained(model_name, output_attentions=True)
10
+
11
+ inputs = tokenizer(sentence, return_tensors='pt')
12
+ outputs = model(**inputs)
13
+ attentions = outputs.attentions # tuple of (layer, batch, head, seq_len, seq_len)
14
+
15
+ tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
16
+
17
+ fig, ax = plt.subplots(figsize=(10, 8))
18
+ sns.heatmap(attentions[-1][0][0].detach().numpy(),
19
+ xticklabels=tokens,
20
+ yticklabels=tokens,
21
+ cmap="viridis",
22
+ ax=ax)
23
+ ax.set_title(f"Attention Map - Layer {len(attentions)} Head 1")
24
+ plt.xticks(rotation=90)
25
+ plt.yticks(rotation=0)
26
+
27
+ return fig
28
+
29
+ model_list = [
30
+ "bert-base-uncased",
31
+ "roberta-base",
32
+ "distilbert-base-uncased"
33
+ ]
34
+
35
+ iface = gr.Interface(
36
+ fn=visualize_attention,
37
+ inputs=[
38
+ gr.Dropdown(choices=model_list, label="Choose Transformer Model"),
39
+ gr.Textbox(label="Enter Input Sentence")
40
+ ],
41
+ outputs=gr.Plot(label="Attention Map"),
42
+ title="Transformer Attention Visualizer",
43
+ description="Visualize attention heads of transformer models. Select a model and input text to see attention heatmaps."
44
+ )
45
+
46
+ iface.launch()