Aranwer committed (verified)
Commit ad92e07 · 1 Parent(s): 70e2621

Update app.py

Files changed (1):
  app.py +83 -21
app.py CHANGED
@@ -1,46 +1,108 @@
 import gradio as gr
-from transformers import AutoTokenizer, AutoModel
+from transformers import AutoTokenizer, AutoModel, AutoModelForSeq2SeqLM, GPT2Model
 import torch
 import matplotlib.pyplot as plt
 import seaborn as sns
 
-def visualize_attention(model_name, sentence):
+MODEL_INFO = {
+    "bert-base-uncased": {
+        "Model Type": "BERT",
+        "Layers": 12,
+        "Attention Heads": 12,
+        "Parameters": "109.48M"
+    },
+    "roberta-base": {
+        "Model Type": "RoBERTa",
+        "Layers": 12,
+        "Attention Heads": 12,
+        "Parameters": "125M"
+    },
+    "distilbert-base-uncased": {
+        "Model Type": "DistilBERT",
+        "Layers": 6,
+        "Attention Heads": 12,
+        "Parameters": "66M"
+    },
+    "gpt2": {
+        "Model Type": "GPT-2",
+        "Layers": 12,
+        "Attention Heads": 12,
+        "Parameters": "124M"
+    },
+    "t5-small": {
+        "Model Type": "T5",
+        "Layers": 6,
+        "Attention Heads": 8,
+        "Parameters": "60M"
+    }
+}
+
+def visualize_transformer(model_name, sentence):
     tokenizer = AutoTokenizer.from_pretrained(model_name)
-    model = AutoModel.from_pretrained(model_name, output_attentions=True)
 
-    inputs = tokenizer(sentence, return_tensors='pt')
-    outputs = model(**inputs)
-    attentions = outputs.attentions  # tuple of (layer, batch, head, seq_len, seq_len)
+    if "t5" in model_name:
+        model = AutoModelForSeq2SeqLM.from_pretrained(model_name, output_attentions=True)
+        inputs = tokenizer(sentence, return_tensors='pt')
+    elif "gpt2" in model_name:
+        model = GPT2Model.from_pretrained(model_name, output_attentions=True)
+        tokenizer.pad_token = tokenizer.eos_token
+        inputs = tokenizer(sentence, return_tensors='pt', padding=True)
+    else:
+        model = AutoModel.from_pretrained(model_name, output_attentions=True)
+        inputs = tokenizer(sentence, return_tensors='pt')
 
+    outputs = model(**inputs)
+    attentions = outputs.attentions
     tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
-
+
     fig, ax = plt.subplots(figsize=(10, 8))
-    sns.heatmap(attentions[-1][0][0].detach().numpy(),
-                xticklabels=tokens,
-                yticklabels=tokens,
-                cmap="viridis",
+    sns.heatmap(attentions[-1][0][0].detach().numpy(),
+                xticklabels=tokens,
+                yticklabels=tokens,
+                cmap="viridis",
                 ax=ax)
     ax.set_title(f"Attention Map - Layer {len(attentions)} Head 1")
     plt.xticks(rotation=90)
     plt.yticks(rotation=0)
-
-    return fig
 
-model_list = [
-    "bert-base-uncased",
-    "roberta-base",
-    "distilbert-base-uncased"
-]
+    token_output = [f"{i}: \"{tok}\"" for i, tok in enumerate(tokens)]
+    token_output_str = "[\\n" + "\\n".join(token_output) + "\\n]"
+
+    model_info = MODEL_INFO.get(model_name, {})
+    details = f"""
+🛠 Model Details
+Model Type: {model_info.get("Model Type", "Unknown")}
+
+Number of Layers: {model_info.get("Layers", "?")}
+
+Number of Attention Heads: {model_info.get("Attention Heads", "?")}
+
+Total Parameters: {model_info.get("Parameters", "?")}
+
+📊 Tokenization Visualization
+Enter Text:
+{sentence}
+
+Tokenized Output:
+{token_output_str}
+"""
+
+    return details, fig
+
+model_list = list(MODEL_INFO.keys())
 
 iface = gr.Interface(
-    fn=visualize_attention,
+    fn=visualize_transformer,
     inputs=[
         gr.Dropdown(choices=model_list, label="Choose Transformer Model"),
         gr.Textbox(label="Enter Input Sentence")
     ],
-    outputs=gr.Plot(label="Attention Map"),
+    outputs=[
+        gr.Textbox(label="🧠 Model + Token Info", lines=20),
+        gr.Plot(label="🧩 Attention Map")
+    ],
     title="Transformer Attention Visualizer",
-    description="Visualize attention heads of transformer models. Select a model and input text to see attention heatmaps."
+    description="Visualize attention heads of transformer models with detailed model and token information."
 )
 
 iface.launch()
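
Note (not part of the commit): the new t5-small path looks fragile. AutoModelForSeq2SeqLM loads an encoder-decoder model, so model(**inputs) with encoder inputs alone raises an error (T5's forward pass requires decoder_input_ids or labels), and its output exposes encoder_attentions / decoder_attentions rather than a plain .attentions field. A minimal sketch of a possible follow-up, under those assumptions:

    # Hypothetical follow-up for the t5 branch, not part of commit ad92e07
    if "t5" in model_name:
        model = AutoModelForSeq2SeqLM.from_pretrained(model_name, output_attentions=True)
        inputs = tokenizer(sentence, return_tensors='pt')
        # T5 is encoder-decoder: forward() needs decoder inputs, so feed the
        # encoder ids back in; that is enough for a visualization demo.
        outputs = model(**inputs, decoder_input_ids=inputs['input_ids'])
        attentions = outputs.encoder_attentions  # encoder self-attention maps
    else:
        model = AutoModel.from_pretrained(model_name, output_attentions=True)
        inputs = tokenizer(sentence, return_tensors='pt')
        outputs = model(**inputs)
        attentions = outputs.attentions

The encoder-only models are unaffected: bert-base-uncased, for example, returns a 12-tuple of attention tensors shaped (batch, heads, seq_len, seq_len), so attentions[-1][0][0] is the last layer's first head, matching the "Head 1" in the plot title.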
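
Two smaller nits in the added code, sketched as hypothetical tweaks rather than fixes in this commit: "[\\n" + "\\n".join(token_output) + "\\n]" appears to join tokens with a literal backslash-n, so the Tokenized Output section would render on a single line in the textbox, and the heatmap hard-codes the first head of the last layer:

    # Hypothetical tweaks, not in this commit
    token_output_str = "[\n" + "\n".join(token_output) + "\n]"  # real newlines for gr.Textbox

    head = 0  # first head; attentions[-1] has shape (batch, heads, seq_len, seq_len)
    sns.heatmap(attentions[-1][0, head].detach().numpy(),
                xticklabels=tokens, yticklabels=tokens, cmap="viridis", ax=ax)
    ax.set_title(f"Attention Map - Layer {len(attentions)} Head {head + 1}")

If users should be able to pick the layer and head, two gr.Slider inputs would slot into the existing Interface; the sketch above keeps the commit's fixed choice.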