rahideer committed on
Commit
14ee668
·
verified ·
1 Parent(s): b1bec5c

Update utils.py

Browse files
Files changed (1) hide show
  1. utils.py +20 -2
utils.py CHANGED
@@ -1,5 +1,6 @@
1
  import plotly.graph_objects as go
2
  import numpy as np
 
3
 
4
  def list_supported_models(task):
5
  if task == "Text Classification":
@@ -12,9 +13,9 @@ def list_supported_models(task):
12
 
13
  def visualize_attention(attentions, tokenizer, inputs):
14
  tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
15
- last_layer_attention = attentions[-1][0] # shape: [num_heads, seq_len, seq_len]
16
  avg_attention = last_layer_attention.mean(dim=0).detach().numpy()
17
-
18
  fig = go.Figure(data=go.Heatmap(
19
  z=avg_attention,
20
  x=tokens,
@@ -23,3 +24,20 @@ def visualize_attention(attentions, tokenizer, inputs):
23
  ))
24
  fig.update_layout(title="Average Attention - Last Layer", xaxis_nticks=len(tokens))
25
  return fig
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import plotly.graph_objects as go
2
  import numpy as np
3
+ from sklearn.decomposition import PCA
4
 
5
  def list_supported_models(task):
6
  if task == "Text Classification":
 
13
 
14
  def visualize_attention(attentions, tokenizer, inputs):
15
  tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
16
+ last_layer_attention = attentions[-1][0] # [heads, seq_len, seq_len]
17
  avg_attention = last_layer_attention.mean(dim=0).detach().numpy()
18
+
19
  fig = go.Figure(data=go.Heatmap(
20
  z=avg_attention,
21
  x=tokens,
 
24
  ))
25
  fig.update_layout(title="Average Attention - Last Layer", xaxis_nticks=len(tokens))
26
  return fig
27
+
28
def plot_token_embeddings(embeddings, tokens):
    """Project token embeddings into 2-D with PCA and scatter-plot them.

    Each token is added as its own labelled marker trace so every token
    appears individually in the figure legend.

    Args:
        embeddings: tensor of token embeddings; assumed shape
            (num_tokens, hidden_dim) — TODO confirm against caller.
        tokens: sequence of token strings aligned row-for-row with
            `embeddings`.

    Returns:
        A plotly Figure with one markers+text point per token.
    """
    # Collapse the hidden dimension down to the two leading principal
    # components so the embeddings can be drawn on a flat plane.
    coords = PCA(n_components=2).fit_transform(embeddings.detach().numpy())

    fig = go.Figure()
    for idx, tok in enumerate(tokens):
        pc1, pc2 = coords[idx]
        fig.add_trace(go.Scatter(
            x=[pc1],
            y=[pc2],
            text=[tok],
            mode='markers+text',
            textposition='top center',
            marker=dict(size=10),
            name=tok,
        ))
    fig.update_layout(title="Token Embeddings (PCA)", xaxis_title="PC 1", yaxis_title="PC 2")
    return fig