import plotly.graph_objects as go
import numpy as np
from sklearn.decomposition import PCA
def list_supported_models(task):
    """Return the model checkpoints offered for a given task."""
    if task == "Text Classification":
        return ["distilbert-base-uncased", "bert-base-uncased", "roberta-base"]
    elif task == "Text Generation":
        return ["gpt2", "distilgpt2"]
    elif task == "Question Answering":
        return ["deepset/roberta-base-squad2", "distilbert-base-cased-distilled-squad"]
    return []
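# Example usage (illustrative, not part of the app's UI wiring):
#   list_supported_models("Text Generation")  ->  ["gpt2", "distilgpt2"]
#   list_supported_models("Summarization")    ->  []   (unknown tasks fall through)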
def visualize_attention(attentions, tokenizer, inputs):
    """Plot the head-averaged attention of the last layer as a token-by-token heatmap."""
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    last_layer_attention = attentions[-1][0]  # [heads, seq_len, seq_len]
    # Average over attention heads and move to NumPy for plotting.
    avg_attention = last_layer_attention.mean(dim=0).detach().numpy()
    fig = go.Figure(data=go.Heatmap(
        z=avg_attention,
        x=tokens,
        y=tokens,
        colorscale='Viridis'
    ))
    fig.update_layout(title="Average Attention - Last Layer", xaxis_nticks=len(tokens))
    return fig
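# A minimal sketch of how visualize_attention might be driven. The checkpoint
# name and the use of AutoModel with output_attentions=True are assumptions;
# this file does not show how `attentions` is produced upstream.
def _demo_visualize_attention():
    from transformers import AutoModel, AutoTokenizer  # assumed dependency
    tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
    model = AutoModel.from_pretrained("distilbert-base-uncased", output_attentions=True)
    inputs = tokenizer("The quick brown fox jumps over the lazy dog", return_tensors="pt")
    outputs = model(**inputs)
    # outputs.attentions is a tuple with one [batch, heads, seq, seq] tensor per layer.
    fig = visualize_attention(outputs.attentions, tokenizer, inputs)
    fig.show()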
def plot_token_embeddings(embeddings, tokens):
    """Project per-token embeddings to 2-D with PCA and scatter-plot them."""
    pca = PCA(n_components=2)
    reduced = pca.fit_transform(embeddings.detach().numpy())
    fig = go.Figure()
    # One labeled marker per token so each point is identifiable in the legend.
    for i, token in enumerate(tokens):
        fig.add_trace(go.Scatter(
            x=[reduced[i][0]], y=[reduced[i][1]],
            text=[token],
            mode='markers+text',
            textposition='top center',
            marker=dict(size=10),
            name=token
        ))
    fig.update_layout(title="Token Embeddings (PCA)", xaxis_title="PC 1", yaxis_title="PC 2")
    return fig
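# A minimal sketch of feeding plot_token_embeddings. Using the last hidden
# state as the per-token embeddings is an assumption; any [seq_len, dim]
# tensor paired with matching tokens would work.
def _demo_plot_token_embeddings():
    from transformers import AutoModel, AutoTokenizer  # assumed dependency
    tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
    model = AutoModel.from_pretrained("distilbert-base-uncased")
    inputs = tokenizer("The quick brown fox", return_tensors="pt")
    outputs = model(**inputs)
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    # last_hidden_state[0] has shape [seq_len, hidden_dim]; PCA needs seq_len >= 2.
    fig = plot_token_embeddings(outputs.last_hidden_state[0], tokens)
    fig.show()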