marianvd-01 committed
Commit 5e06d00 · verified · 1 Parent(s): 59d5a80

Create visualize.py

Files changed (1)
  1. visualize.py +90 -0
visualize.py ADDED
@@ -0,0 +1,90 @@
# visualize.py - Contains functions to draw:
#   - Attention matrix
#   - Tokenization preview
#   - Embedding heatmaps
#   - Model comparison chart

import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import torch
from sklearn.decomposition import PCA


def plot_attention(tokens, attn_matrix):
    """Draw a token-by-token heatmap of a single attention head."""
    fig, ax = plt.subplots(figsize=(8, 6))
    cax = ax.matshow(attn_matrix, cmap="viridis")
    fig.colorbar(cax)
    ax.set_xticks(range(len(tokens)))
    ax.set_yticks(range(len(tokens)))
    ax.set_xticklabels(tokens, rotation=90)
    ax.set_yticklabels(tokens)
    ax.set_title("Attention Map")
    plt.tight_layout()
    return fig


def visualize_attention(tokenizer, model, text, layer_index, head_index):
    """Run the model on `text` and plot one head of one attention layer."""
    inputs = tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        # Request attention weights explicitly; otherwise outputs.attentions is None
        # unless the model was loaded with output_attentions=True.
        outputs = model(**inputs, output_attentions=True)

    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    attn = outputs.attentions[layer_index][0, head_index].detach().numpy()
    return plot_attention(tokens, attn)


def show_tokenization(tokenizer, text):
    """Render the tokenizer output as a single labelled row."""
    tokens = tokenizer.tokenize(text)
    fig, ax = plt.subplots(figsize=(8, 1))
    ax.imshow([[0] * len(tokens)], cmap="Pastel2", aspect="auto")
    ax.set_xticks(range(len(tokens)))
    ax.set_xticklabels(tokens, rotation=90)
    ax.set_yticks([])
    ax.set_title("Tokenization")
    return fig


def show_embeddings(tokenizer, model, text):
    """Project the final hidden states to 2D with PCA and scatter-plot the tokens."""
    inputs = tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)

    embeddings = outputs.last_hidden_state[0].detach().numpy()
    pca = PCA(n_components=2)
    reduced = pca.fit_transform(embeddings)

    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    fig, ax = plt.subplots()
    ax.scatter(reduced[:, 0], reduced[:, 1])

    for i, token in enumerate(tokens):
        ax.annotate(token, (reduced[i, 0], reduced[i, 1]))

    ax.set_title("Token Embeddings (PCA)")
    return fig


def compare_model_sizes():
    """Bar-chart the parameter counts (in millions) of the configured models."""
    from model_utils import MODEL_OPTIONS
    from transformers import AutoModel

    model_names = list(MODEL_OPTIONS.values())
    sizes = []

    for name in model_names:
        try:
            model = AutoModel.from_pretrained(name)
            size = sum(p.numel() for p in model.parameters()) / 1e6  # in millions
            sizes.append(size)
        except Exception:
            # Show models that failed to load as zero-height bars; matplotlib's
            # bar() cannot plot None values.
            sizes.append(0.0)

    fig, ax = plt.subplots()
    ax.bar(list(MODEL_OPTIONS.keys()), sizes, color="skyblue")
    ax.set_ylabel("Parameters (Millions)")
    ax.set_title("Model Size Comparison")
    ax.tick_params(axis='x', rotation=45)
    plt.tight_layout()
    return fig
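
For context, a minimal driver sketch (not part of this commit): it assumes a standard Hugging Face checkpoint such as bert-base-uncased loaded with AutoTokenizer/AutoModel, and simply saves the returned figures to disk; how the surrounding app actually consumes the figures is not shown here.

# demo_visualize.py - hypothetical usage example, not included in the commit
from transformers import AutoModel, AutoTokenizer

import visualize

model_name = "bert-base-uncased"  # assumed checkpoint; any encoder that returns attentions works
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)
model.eval()

text = "Attention visualizations make transformer behaviour easier to inspect."

figs = {
    "tokenization": visualize.show_tokenization(tokenizer, text),
    "embeddings": visualize.show_embeddings(tokenizer, model, text),
    "attention": visualize.visualize_attention(tokenizer, model, text, layer_index=0, head_index=0),
}

for name, fig in figs.items():
    fig.savefig(f"{name}.png", bbox_inches="tight")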