Aranwer commited on
Commit
005c98d
·
verified ·
1 Parent(s): ad92e07

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -2
app.py CHANGED
@@ -42,6 +42,7 @@ def visualize_transformer(model_name, sentence):
42
 
43
  if "t5" in model_name:
44
  model = AutoModelForSeq2SeqLM.from_pretrained(model_name, output_attentions=True)
 
45
  inputs = tokenizer(sentence, return_tensors='pt')
46
  elif "gpt2" in model_name:
47
  model = GPT2Model.from_pretrained(model_name, output_attentions=True)
@@ -65,8 +66,8 @@ def visualize_transformer(model_name, sentence):
65
  plt.xticks(rotation=90)
66
  plt.yticks(rotation=0)
67
 
68
- token_output = [f"{i}: \"{tok}\"" for i, tok in enumerate(tokens)]
69
- token_output_str = "[\\n" + "\\n".join(token_output) + "\\n]"
70
 
71
  model_info = MODEL_INFO.get(model_name, {})
72
  details = f"""
 
42
 
43
  if "t5" in model_name:
44
  model = AutoModelForSeq2SeqLM.from_pretrained(model_name, output_attentions=True)
45
+ sentence = "translate English to English: " + sentence
46
  inputs = tokenizer(sentence, return_tensors='pt')
47
  elif "gpt2" in model_name:
48
  model = GPT2Model.from_pretrained(model_name, output_attentions=True)
 
66
  plt.xticks(rotation=90)
67
  plt.yticks(rotation=0)
68
 
69
+ token_output = [f"{i + 1}: \"{tok}\"" for i, tok in enumerate(tokens)]
70
+ token_output_str = "[\n" + "\n".join(token_output) + "\n]"
71
 
72
  model_info = MODEL_INFO.get(model_name, {})
73
  details = f"""