import streamlit as st

# Compatibility shim: older Streamlit releases only provide
# st.experimental_singleton, the predecessor of st.cache_resource.
if not hasattr(st, "cache_resource"):
    st.cache_resource = st.experimental_singleton

import pandas as pd
import torch
import torch.nn.functional as F
from transformers import MarianMTModel, MarianTokenizer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model_options = [
    'Helsinki-NLP/opus-mt-roa-en',
    'Helsinki-NLP/opus-mt-en-roa',
]

# Two-column layout: model/language selection on the left, text entry on the right.
col1, col2 = st.columns(2)

with col1:
    model_name = st.selectbox("Select a model", model_options + ['other'])

    if model_name == 'other':
        model_name = st.text_input("Enter model name", model_options[0])

# Cache the tokenizer and model across Streamlit reruns so that editing
# other widgets does not reload them from disk.
@st.cache_resource
def get_tokenizer(model_name):
    return MarianTokenizer.from_pretrained(model_name)

@st.cache_resource
def get_model(model_name):
    model = MarianMTModel.from_pretrained(model_name).to(device)
    # Logged to the server console, not the app page.
    print(f"Loaded model, {model.num_parameters():,d} parameters.")
    return model

tokenizer = get_tokenizer(model_name)
model = get_model(model_name)

# Multilingual Marian models list target-language tokens (e.g. '>>spa<<');
# single-target models have an empty list and need no language code.
if tokenizer.supported_language_codes:
    lang_code = st.selectbox("Select a language", tokenizer.supported_language_codes)
else:
    lang_code = None


with col2:
    input_text = st.text_input("Enter text to translate", "Hola, mi nombre es Juan")

# Stop rendering the rest of the page until there is text to translate.
input_text = input_text.strip()
if not input_text:
    st.stop()

# prepend the language code if necessary
if lang_code:
    input_text = f"{lang_code} {input_text}"


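# Encode the source sentence and move it to the model's device.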
input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(device)

# Translate with beam search, returning all four beams so the user can
# compare candidate translations.
example_generations = model.generate(
    input_ids,
    num_beams=4,
    num_return_sequences=4,
)

col1, col2 = st.columns(2)
with col1:
    st.write("Example generations:")
    st.write('\n'.join(
        '- ' + translation
        for translation in tokenizer.batch_decode(example_generations, skip_special_tokens=True)))

with col2:
    # Position 0 of each generation is the decoder start token, so
    # position 1 holds the first generated token.
    example_first_word = tokenizer.decode(example_generations[0, 1])
    output_so_far = st.text_input("Enter text translated so far", example_first_word)


# Tokenize the translation prefix. Marian tokenizers use a separate
# SentencePiece model for the target language, so switch to it first.
with tokenizer.as_target_tokenizer():
    output_tokens = tokenizer.tokenize(output_so_far)
    decoder_input_ids = tokenizer.convert_tokens_to_ids(output_tokens)

# Add the start token
decoder_input_ids = [model.config.decoder_start_token_id] + decoder_input_ids
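# (Marian models use the pad token as the decoder start token.)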

# One forward pass with the prefix teacher-forced into the decoder.
with torch.no_grad():
    model_output = model(
        input_ids=input_ids,
        decoder_input_ids=torch.tensor([decoder_input_ids]).to(device))

# Logits at the final decoder position score every possible next token;
# keep the 20 highest-scoring candidates.
last_token_logits = model_output.logits[0, -1].cpu()
assert len(last_token_logits.shape) == 1
most_likely_tokens = last_token_logits.topk(k=20)

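# Normalize the full logit vector so each candidate gets a probability.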
probs = last_token_logits.softmax(dim=-1)
probs_for_likely_tokens = probs[most_likely_tokens.indices]

# Decode the candidate ids with the target tokenizer and tabulate their
# probabilities.
with tokenizer.as_target_tokenizer():
    probs_table = pd.DataFrame({
        'token': [tokenizer.decode(token_id) for token_id in most_likely_tokens.indices],
        'id': most_likely_tokens.indices,
        'probability': probs_for_likely_tokens,
        'logprob': probs_for_likely_tokens.log(),
        'cumulative probability': probs_for_likely_tokens.cumsum(0),
    })

st.subheader("Most likely next tokens")
st.table(probs_table.style.hide(axis='index'))

if len(decoder_input_ids) > 1:
    st.subheader("Loss by already-generated token")
    # Logits at position i predict token i + 1, so align logits[:-1]
    # with decoder_input_ids[1:] to score each prefix token.
    loss_table = pd.DataFrame({
        'token': [tokenizer.decode(token_id) for token_id in decoder_input_ids[1:]],
        'loss': F.cross_entropy(
            model_output.logits[0, :-1],
            torch.tensor(decoder_input_ids[1:]).to(device),
            reduction='none').cpu(),
    })
    st.write(loss_table)
    st.write("Total loss so far:", loss_table.loss.sum())