Spaces:
Sleeping
Sleeping
File size: 3,605 Bytes
65b683f 325ca0f 65b683f 9e882df 325ca0f 9e882df 325ca0f 9e882df 325ca0f 9e882df 325ca0f 9e882df 325ca0f 65b683f abc9e3b 65b683f abc9e3b 65b683f 5c41bd3 65b683f 206be88 abc9e3b 9e882df 65b683f 325ca0f 65b683f abc9e3b 65b683f abc9e3b 325ca0f abc9e3b 65b683f 9e882df 65b683f 47f3f6e 325ca0f abc9e3b 47f3f6e 325ca0f 47f3f6e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 |
import streamlit as st
# Compatibility shim: when this Streamlit build has no `st.cache_resource`,
# expose `st.experimental_singleton` under that name so the decorators below
# work either way.
try:
    st.cache_resource
except AttributeError:
    st.cache_resource = st.experimental_singleton
import torch
import torch.nn as nn
import torch.nn.functional as F
import transformers
import pandas as pd
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
from transformers import MarianMTModel, MarianTokenizer
# Marian checkpoints offered in the model dropdown.
model_options = [
    'Helsinki-NLP/opus-mt-roa-en',
    'Helsinki-NLP/opus-mt-en-roa',
]

# Two-column layout; the model picker goes in the left column.
col1, col2 = st.columns(2)
with col1:
    model_name = st.selectbox("Select a model", model_options + ['other'])
    # Choosing "other" reveals a free-text box, pre-filled with the first
    # listed checkpoint, so any model name can be typed in.
    # NOTE(review): the source's indentation was lost; placing this box inside
    # the left column is a reconstruction -- confirm.
    if model_name == 'other':
        model_name = st.text_input("Enter model name", model_options[0])

# A cleared "other" box would otherwise be handed to `from_pretrained` as an
# empty name and fail there; halt this run quietly instead, the same way the
# script treats an empty text to translate.
model_name = model_name.strip()
if not model_name:
    st.stop()
@st.cache_resource
def get_tokenizer(model_name):
    """Return the `MarianTokenizer` for the checkpoint called `model_name`.

    The `st.cache_resource` decorator is applied so that Streamlit can hand
    back the already-loaded tokenizer on later calls with the same name
    rather than loading it again (see the Streamlit caching documentation).
    """
    loaded_tokenizer = MarianTokenizer.from_pretrained(model_name)
    return loaded_tokenizer
@st.cache_resource
def get_model(model_name):
    """Return the `MarianMTModel` for `model_name`, moved onto `device`.

    Prints the model's parameter count to stdout once it has been loaded.
    The `st.cache_resource` decorator is applied so that Streamlit can reuse
    the loaded model on later calls with the same name (see the Streamlit
    caching documentation).
    """
    loaded_model = MarianMTModel.from_pretrained(model_name)
    loaded_model = loaded_model.to(device)
    parameter_count = loaded_model.num_parameters()
    print(f"Loaded model, {parameter_count:,d} parameters.")
    return loaded_model
# Load (or fetch from Streamlit's cache) the tokenizer and model for the
# selected checkpoint.
tokenizer = get_tokenizer(model_name)
model = get_model(model_name)

# Offer a language picker only when the tokenizer lists language codes;
# otherwise no language code is used.
# NOTE(review): the source's indentation was lost; this picker is emitted at
# top level rather than inside a column -- confirm the intended placement.
available_lang_codes = tokenizer.supported_language_codes
lang_code = (
    st.selectbox("Select a language", available_lang_codes)
    if available_lang_codes
    else None
)
# Source text goes in the right-hand column.
with col2:
    input_text = st.text_input("Enter text to translate", "Hola, mi nombre es Juan")

# Nothing to translate once surrounding whitespace is removed: end this run.
input_text = input_text.strip()
if not input_text:
    st.stop()

# When a language code was picked, it is prepended to the text, separated by
# a single space.
if lang_code:
    input_text = f"{lang_code} {input_text}"

# Encode the source text as PyTorch tensors and move the ids to `device`.
encoded_input = tokenizer(input_text, return_tensors="pt")
input_ids = encoded_input.input_ids.to(device)

# Generate four candidate translations using four beams.
generation_options = {"num_beams": 4, "num_return_sequences": 4}
example_generations = model.generate(input_ids, **generation_options)
# Second row of the layout: candidate translations on the left, the
# translation prefix under inspection on the right.
col1, col2 = st.columns(2)

with col1:
    st.write("Example generations:")
    decoded_examples = tokenizer.batch_decode(
        example_generations, skip_special_tokens=True
    )
    # One "- " prefixed line per candidate, written as a single string.
    bullet_lines = ['- ' + translation for translation in decoded_examples]
    st.write('\n'.join(bullet_lines))

with col2:
    # Pre-fill the box with the token at position 1 of the first candidate
    # (position 0 is presumably the decoder start token -- confirm).
    first_generation = example_generations[0]
    example_first_word = tokenizer.decode(first_generation[1])
    output_so_far = st.text_input("Enter text translated so far", example_first_word)
# Tokenize the translation prefix with the tokenizer switched to its target
# side, then map the tokens to ids.
# NOTE(review): the source's indentation was lost; keeping
# `convert_tokens_to_ids` inside the target-tokenizer context is a
# reconstruction -- confirm against the original layout.
with tokenizer.as_target_tokenizer():
    output_tokens = tokenizer.tokenize(output_so_far)
    prefix_token_ids = tokenizer.convert_tokens_to_ids(output_tokens)

# The decoder input is the model's configured start token followed by the
# prefix ids.
decoder_input_ids = [model.config.decoder_start_token_id, *prefix_token_ids]

# Single forward pass with a batch of one; gradients are not needed.
decoder_input_tensor = torch.tensor([decoder_input_ids]).to(device)
with torch.no_grad():
    model_output = model(
        input_ids=input_ids,
        decoder_input_ids=decoder_input_tensor,
    )
# Logits at the last decoder position of the single batch item, on the CPU.
last_token_logits = model_output.logits[0, -1].cpu()
assert last_token_logits.ndim == 1

# Probabilities over the whole vocabulary, and the 20 highest-scoring tokens.
probs = last_token_logits.softmax(dim=-1)
most_likely_tokens = last_token_logits.topk(k=20)
top_token_ids = most_likely_tokens.indices
probs_for_likely_tokens = probs[top_token_ids]

with tokenizer.as_target_tokenizer():
    decoded_top_tokens = [tokenizer.decode(token_id) for token_id in top_token_ids]
    # The cumulative column is a running sum over these 20 rows only, in the
    # order `topk` returned them.
    probs_table = pd.DataFrame({
        'token': decoded_top_tokens,
        'id': top_token_ids,
        'probability': probs_for_likely_tokens,
        'logprob': probs_for_likely_tokens.log(),
        'cumulative probability': probs_for_likely_tokens.cumsum(0),
    })

st.subheader("Most likely next tokens")
# Render as a static table with the DataFrame index hidden.
styled_probs_table = probs_table.style.hide(axis='index')
st.table(styled_probs_table)
# Everything after the start token is a token the user has already supplied;
# the loss table is only shown when at least one such token exists.
already_generated_ids = decoder_input_ids[1:]
if already_generated_ids:
    st.subheader("Loss by already-generated token")

    # Logits at position i are scored against the token at position i + 1,
    # so the final position's logits (the next-token prediction) are left
    # out. `reduction='none'` keeps one loss value per token.
    per_token_loss = F.cross_entropy(
        model_output.logits[0, :-1],
        torch.tensor(already_generated_ids).to(device),
        reduction='none',
    ).cpu()

    loss_table = pd.DataFrame({
        'token': [tokenizer.decode(token_id) for token_id in already_generated_ids],
        'loss': per_token_loss,
    })
    st.write(loss_table)
    st.write("Total loss so far:", loss_table.loss.sum())