Update app.py
Browse files
app.py
CHANGED
@@ -42,6 +42,7 @@ def visualize_transformer(model_name, sentence):
|
|
42 |
|
43 |
if "t5" in model_name:
|
44 |
model = AutoModelForSeq2SeqLM.from_pretrained(model_name, output_attentions=True)
|
|
|
45 |
inputs = tokenizer(sentence, return_tensors='pt')
|
46 |
elif "gpt2" in model_name:
|
47 |
model = GPT2Model.from_pretrained(model_name, output_attentions=True)
|
@@ -65,8 +66,8 @@ def visualize_transformer(model_name, sentence):
|
|
65 |
plt.xticks(rotation=90)
|
66 |
plt.yticks(rotation=0)
|
67 |
|
68 |
-
token_output = [f"{i}: \"{tok}\"" for i, tok in enumerate(tokens)]
|
69 |
-
token_output_str = "[
|
70 |
|
71 |
model_info = MODEL_INFO.get(model_name, {})
|
72 |
details = f"""
|
|
|
42 |
|
43 |
if "t5" in model_name:
|
44 |
model = AutoModelForSeq2SeqLM.from_pretrained(model_name, output_attentions=True)
|
45 |
+
sentence = "translate English to English: " + sentence
|
46 |
inputs = tokenizer(sentence, return_tensors='pt')
|
47 |
elif "gpt2" in model_name:
|
48 |
model = GPT2Model.from_pretrained(model_name, output_attentions=True)
|
|
|
66 |
plt.xticks(rotation=90)
|
67 |
plt.yticks(rotation=0)
|
68 |
|
69 |
+
token_output = [f"{i + 1}: \"{tok}\"" for i, tok in enumerate(tokens)]
|
70 |
+
token_output_str = "[\n" + "\n".join(token_output) + "\n]"
|
71 |
|
72 |
model_info = MODEL_INFO.get(model_name, {})
|
73 |
details = f"""
|