use newspaper model
app/app.py  +58 -40
CHANGED
@@ -10,18 +10,18 @@ import torch
 import os
 from abstract_dataset import AbstractDataset
 
-
 # st.set_page_config(page_title="Indonesian GPT-2")
 
-mirror_url = "https://
+mirror_url = "https://news-generator.ai-research.id/"
 if "MIRROR_URL" in os.environ:
     mirror_url = os.environ["MIRROR_URL"]
+hf_auth_token = os.getenv("HF_AUTH_TOKEN", False)
 
 MODELS = {
-    "Indonesian
-        "group": "Indonesian
-        "name": "
-        "description": "
+    "Indonesian Newspaper - Indonesian GPT-2 Medium": {
+        "group": "Indonesian Newspaper",
+        "name": "ai-research-id/gpt2-medium-newspaper",
+        "description": "Newspaper Generator using Indonesian GPT-2 Medium.",
         "text_generator": None,
         "tokenizer": None
     },
@@ -85,7 +85,7 @@ def process(text_generator, tokenizer, title: str, keywords: str, text: str,
     if repetition_penalty == 0.0:
         min_penalty = 1.05
         max_penalty = 1.5
-        repetition_penalty = max(min_penalty + (1.0-temperature) * (max_penalty-min_penalty), 0.8)
+        repetition_penalty = max(min_penalty + (1.0 - temperature) * (max_penalty - min_penalty), 0.8)
 
     keywords = [keyword.strip() for keyword in keywords.split(",")]
     keywords = AbstractDataset.join_keywords(keywords, randomize=False)
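When Repetition Penalty is left in automatic mode, `process` derives it from the temperature: cooler sampling gets a stronger penalty, interpolated linearly between 1.05 and 1.5. A minimal standalone sketch of the same arithmetic, for reference (not part of the diff):

    def auto_repetition_penalty(temperature: float,
                                min_penalty: float = 1.05,
                                max_penalty: float = 1.5) -> float:
        # Linear interpolation: temperature 0.0 -> max_penalty, 1.0 -> min_penalty.
        # The 0.8 floor only engages for temperatures above roughly 1.56.
        return max(min_penalty + (1.0 - temperature) * (max_penalty - min_penalty), 0.8)

    # At the app's default temperature of 0.4: 1.05 + 0.6 * 0.45 = 1.32
    print(auto_repetition_penalty(0.4))  # 1.32

The sentinel for automatic mode is `repetition_penalty == 0.0`, which matches the `repetition_penalty = 0.0` default this diff adds in the sidebar logic below.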
@@ -102,15 +102,16 @@ def process(text_generator, tokenizer, title: str, keywords: str, text: str,
 
     text_generator.eval()
     sample_outputs = text_generator.generate(generated,
-
-
-
-
-
-
-
-
-
+                                             do_sample=do_sample,
+                                             min_length=200,
+                                             max_length=max_length,
+                                             top_k=top_k,
+                                             top_p=top_p,
+                                             temperature=temperature,
+                                             repetition_penalty=repetition_penalty,
+                                             num_return_sequences=1,
+                                             hf_auth_token=hf_auth_token
+                                             )
     result = tokenizer.decode(sample_outputs[0], skip_special_tokens=True)
     print(f"result: {result}")
     prefix_length = len(title) + len(keywords)
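The new keyword arguments are the standard Hugging Face decoding knobs. As a point of reference, a minimal sketch of an equivalent call with plain `transformers` (the `gpt2` checkpoint and prompt are placeholders; `hf_auth_token` appears specific to this app's generator wrapper, so it is omitted here):

    from transformers import GPT2LMHeadModel, GPT2Tokenizer

    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")   # placeholder checkpoint
    model = GPT2LMHeadModel.from_pretrained("gpt2")
    model.eval()

    input_ids = tokenizer.encode("Example headline", return_tensors="pt")
    outputs = model.generate(input_ids,
                             do_sample=True,
                             min_length=200,
                             max_length=250,
                             top_k=30,
                             top_p=0.95,
                             temperature=0.4,
                             repetition_penalty=1.32,
                             num_return_sequences=1)
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))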
@@ -127,9 +128,9 @@ model_name = f"Model name: [{MODELS[model_type]['name']}](https://huggingface.co
 st.markdown(model_name)
 if prompt_group_name in ["Indonesian GPT-2", "Indonesian Literature", "Indonesian Journal"]:
     session_state = SessionState.get(prompt=None, prompt_box=None, text=None)
-    ALL_PROMPTS = list(PROMPT_LIST[prompt_group_name].keys())+["Custom"]
+    ALL_PROMPTS = list(PROMPT_LIST[prompt_group_name].keys()) + ["Custom"]
 
-    prompt = st.selectbox('Prompt', ALL_PROMPTS, index=len(ALL_PROMPTS)-1)
+    prompt = st.selectbox('Prompt', ALL_PROMPTS, index=len(ALL_PROMPTS) - 1)
 
     # Update prompt
     if session_state.prompt is None:
@@ -160,6 +161,12 @@ if prompt_group_name in ["Indonesian GPT-2", "Indonesian Literature", "Indonesia
         help="The maximum length of the sequence to be generated."
     )
 
+    decoding_methods = st.sidebar.radio(
+        "Set the decoding methods:",
+        key="decoding",
+        options=["Beam Search", "Sampling", "Contrastive Search"],
+    )
+
     temperature = st.sidebar.slider(
         "Temperature",
         value=0.4,
@@ -167,15 +174,14 @@ if prompt_group_name in ["Indonesian GPT-2", "Indonesian Literature", "Indonesia
         max_value=2.0
     )
 
-    do_sample = st.sidebar.checkbox(
-        "Use sampling",
-        value=True
-    )
-
     top_k = 30
     top_p = 0.95
+    repetition_penalty = 0.0
 
-    if
+    if decoding_methods == "Beam Search":
+        do_sample = False
+    elif decoding_methods == "Sampling":
+        do_sample = True
         top_k = st.sidebar.number_input(
             "Top k",
             value=top_k,
help="If set to float < 1, only the most probable tokens with probabilities that add up to top_p or higher "
|
194 |
"are kept for generation."
|
195 |
)
|
196 |
+
else:
|
197 |
+
do_sample = False
|
198 |
+
repetition_penalty = 1.0
|
199 |
+
penalty_alpha = st.sidebar.number_input(
|
200 |
+
"Penalty alpha",
|
201 |
+
value=0.6,
|
202 |
+
help="The penalty alpha for contrastive search."
|
203 |
+
)
|
204 |
+
top_k = st.sidebar.number_input(
|
205 |
+
"Top k",
|
206 |
+
value=4,
|
207 |
+
help="The number of highest probability vocabulary tokens to keep for top-k-filtering."
|
208 |
+
)
|
209 |
|
210 |
seed = st.sidebar.number_input(
|
211 |
"Random Seed",
|
|
|
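Contrastive search is triggered in `transformers` (4.24 or later) by passing `penalty_alpha > 0` together with `top_k > 1` while sampling is off, which is what the defaults above do. A minimal standalone sketch (placeholder checkpoint and prompt):

    from transformers import AutoModelForCausalLM, AutoTokenizer

    tok = AutoTokenizer.from_pretrained("gpt2")             # placeholder
    model = AutoModelForCausalLM.from_pretrained("gpt2")
    ids = tok("Example prompt", return_tensors="pt").input_ids
    out = model.generate(ids,
                         penalty_alpha=0.6,  # degeneration penalty
                         top_k=4,            # candidate pool size
                         max_length=200)
    print(tok.decode(out[0], skip_special_tokens=True))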
@@ -194,22 +213,21 @@ if prompt_group_name in ["Indonesian GPT-2", "Indonesian Literature", "Indonesia
         help="The number used to initialize a pseudorandom number generator"
     )
 
-
-
-
-
-    )
-
-    if not automatic_repetition_penalty:
-        repetition_penalty = st.sidebar.slider(
-            "Repetition Penalty",
-            value=1.0,
-            min_value=1.0,
-            max_value=2.0
-        )
+    if decoding_methods != "Contrastive Search":
+        automatic_repetition_penalty = st.sidebar.checkbox(
+            "Automatic Repetition Penalty",
+            value=True
+        )
+        if not automatic_repetition_penalty:
+            repetition_penalty = st.sidebar.slider(
+                "Repetition Penalty",
+                value=1.0,
+                min_value=1.0,
+                max_value=2.0
+            )
 
 for group_name in MODELS:
-    if MODELS[group_name]["group"] in ["Indonesian
+    if MODELS[group_name]["group"] in ["Indonesian Newspaper"]:
         MODELS[group_name]["text_generator"], MODELS[group_name]["tokenizer"] = \
             get_generator(MODELS[group_name]["name"])
@@ -226,15 +244,15 @@ if prompt_group_name in ["Indonesian GPT-2", "Indonesian Literature", "Indonesia
                 temperature=temperature, do_sample=do_sample,
                 top_k=int(top_k), top_p=float(top_p), seed=seed, repetition_penalty=repetition_penalty)
             time_end = time.time()
-            time_diff = time_end-time_start
-            #result = result[0]["generated_text"]
+            time_diff = time_end - time_start
+            # result = result[0]["generated_text"]
             st.write(result.replace("\n", "  \n"))
             st.text("Translation")
             translation = translate(result, "en", "id")
             st.write(translation.replace("\n", "  \n"))
             # st.write(f"*do_sample: {do_sample}, top_k: {top_k}, top_p: {top_p}, seed: {seed}*")
             info = f"""
-    *Memory: {memory.total/(1024*1024*1024):.2f}GB, used: {memory.percent}%, available: {memory.available/(1024*1024*1024):.2f}GB*
+    *Memory: {memory.total / (1024 * 1024 * 1024):.2f}GB, used: {memory.percent}%, available: {memory.available / (1024 * 1024 * 1024):.2f}GB*
     *Text generated in {time_diff:.5} seconds*
     """
             st.write(info)
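For completeness: the seed collected in the sidebar is passed into `process`, and deployment configuration comes from the two environment variables introduced at the top of the diff. A sketch of the pattern, assuming `transformers.set_seed` for the actual seeding (that call sits outside the hunks shown):

    import os
    from transformers import set_seed

    mirror_url = os.environ.get("MIRROR_URL", "https://news-generator.ai-research.id/")
    hf_auth_token = os.getenv("HF_AUTH_TOKEN", False)  # falsy when the token is unset

    set_seed(42)  # seeds random, numpy and torch for reproducible sampling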