fixed the tokenizer
app/app.py  CHANGED  +21 -17
@@ -62,17 +62,8 @@ model_type = st.sidebar.selectbox('Model', (MODELS.keys()))
 @st.cache(suppress_st_warning=True, allow_output_mutation=True)
 def get_generator(model_name: str):
     st.write(f"Loading the GPT2 model {model_name}, please wait...")
-    special_tokens = AbstractDataset.special_tokens
     tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=hf_auth_token)
-    tokenizer.
-    config = AutoConfig.from_pretrained(model_name,
-                                        bos_token_id=tokenizer.bos_token_id,
-                                        eos_token_id=tokenizer.eos_token_id,
-                                        sep_token_id=tokenizer.sep_token_id,
-                                        pad_token_id=tokenizer.pad_token_id,
-                                        output_hidden_states=False,
-                                        use_auth_token=hf_auth_token)
-    model = GPT2LMHeadModel.from_pretrained(model_name, config=config, use_auth_token=hf_auth_token)
+    model = GPT2LMHeadModel.from_pretrained(model_name, pad_token_id=tokenizer.eos_token_id, use_auth_token=hf_auth_token)
     model.resize_token_embeddings(len(tokenizer))
     return model, tokenizer

@@ -81,24 +72,35 @@ def get_generator(model_name: str):
 # @st.cache(suppress_st_warning=True, hash_funcs={tokenizers.Tokenizer: id})
 def process(text_generator, tokenizer, title: str, keywords: str, text: str,
             max_length: int = 200, do_sample: bool = True, top_k: int = 50, top_p: float = 0.95,
-            temperature: float = 1.0, max_time: float = 120.0, seed=42, repetition_penalty=1.0):
+            temperature: float = 1.0, max_time: float = 120.0, seed=42, repetition_penalty=1.0,
+            penalty_alpha = 0.6):
     # st.write("Cache miss: process")
     set_seed(seed)
     if repetition_penalty == 0.0:
         min_penalty = 1.05
         max_penalty = 1.5
         repetition_penalty = max(min_penalty + (1.0 - temperature) * (max_penalty - min_penalty), 0.8)
-    print("title:", title)
-    print("keywords:", keywords)
+    # print("title:", title)
+    # print("keywords:", keywords)
     prompt = f"title: {title}\nkeywords: {keywords}\n{text}"
-    print("prompt: ", prompt)

     generated = torch.tensor(tokenizer.encode(prompt)).unsqueeze(0)
     # device = torch.device("cuda")
     # generated = generated.to(device)

+    # print("prompt: ", prompt)
+    print("do_sample:", do_sample)
+    print("penalty_alpha:", penalty_alpha)
+    print("max_length:", max_length)
+    print("top_k:", top_k)
+    print("top_p:", top_p)
+    print("temperature:", temperature)
+    print("max_time:", max_time)
+    print("repetition_penalty:", repetition_penalty)
+
     text_generator.eval()
     sample_outputs = text_generator.generate(generated,
+                                             penalty_alpha=penalty_alpha,
                                              do_sample=do_sample,
                                              min_length=200,
                                              max_length=max_length,
@@ -109,7 +111,7 @@ def process(text_generator, tokenizer, title: str, keywords: str, text: str,
                                              num_return_sequences=1
                                              )
     result = tokenizer.decode(sample_outputs[0], skip_special_tokens=True)
-    print(f"result: {result}")
+    # print(f"result: {result}")
     prefix_length = len(title) + len(keywords) + len("title: keywords: ") + 2
     result = result[prefix_length:]
     return result
@@ -173,6 +175,7 @@ if prompt_group_name in ["Indonesian Newspaper"]:
     top_k = 30
     top_p = 0.95
     repetition_penalty = 0.0
+    penalty_alpha = None

     if decoding_methods == "Beam Search":
         do_sample = False
@@ -191,7 +194,7 @@ if prompt_group_name in ["Indonesian Newspaper"]:
         )
     else:
         do_sample = False
-        repetition_penalty = 1.
+        repetition_penalty = 1.1
         penalty_alpha = st.sidebar.number_input(
             "Penalty alpha",
             value=0.6,
@@ -237,11 +240,12 @@ if prompt_group_name in ["Indonesian Newspaper"]:
                              title=session_state.title,
                              keywords=session_state.keywords,
                              text=session_state.text, max_length=int(max_length),
-                             temperature=temperature, do_sample=do_sample,
+                             temperature=temperature, do_sample=do_sample, penalty_alpha=penalty_alpha,
                              top_k=int(top_k), top_p=float(top_p), seed=seed, repetition_penalty=repetition_penalty)
            time_end = time.time()
            time_diff = time_end - time_start
            # result = result[0]["generated_text"]
+           result = result[:result.find("title:")]
            st.write(result.replace("\n", " \n"))
            st.text("Translation")
            translation = translate(result, "en", "id")
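The two substantive changes are the direct GPT2LMHeadModel load with the tokenizer's EOS id as pad token (replacing the AutoConfig detour and the stray `tokenizer.` statement that likely caused the Space's runtime error) and the new `penalty_alpha` argument threaded through to `generate()`. Below is a minimal standalone sketch of that pattern, not code from app/app.py; the checkpoint name, prompt, and generation values are placeholders.

# Standalone sketch (assumed placeholder values, not the app's own checkpoint).
import torch
from transformers import AutoTokenizer, GPT2LMHeadModel, set_seed

model_name = "gpt2"  # placeholder; the app loads its fine-tuned model with an auth token

tokenizer = AutoTokenizer.from_pretrained(model_name)
# GPT-2 has no pad token, so the commit reuses the EOS id instead of
# rebuilding an AutoConfig with every special-token id.
model = GPT2LMHeadModel.from_pretrained(model_name, pad_token_id=tokenizer.eos_token_id)
model.resize_token_embeddings(len(tokenizer))
model.eval()

set_seed(42)
prompt = "title: Example\nkeywords: economy, news\n"
input_ids = torch.tensor(tokenizer.encode(prompt)).unsqueeze(0)

# With do_sample=False, penalty_alpha > 0 and top_k > 1, recent transformers
# releases run contrastive search; the app's sampling modes instead pass
# penalty_alpha=None and keep ordinary sampling.
outputs = model.generate(
    input_ids,
    do_sample=False,
    penalty_alpha=0.6,
    top_k=4,
    max_length=200,
    repetition_penalty=1.1,
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

Relative to the old AutoConfig call, only the pad token id is overridden; the BOS/EOS/SEP ids now come from the checkpoint's saved config.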