cahya committed on
Commit
b77fd74
·
1 Parent(s): a771b16

use newspaper model

Browse files
Files changed (1) hide show
  1. app/app.py +58 -40
app/app.py CHANGED
@@ -10,18 +10,18 @@ import torch
10
  import os
11
  from abstract_dataset import AbstractDataset
12
 
13
-
14
  # st.set_page_config(page_title="Indonesian GPT-2")
15
 
16
- mirror_url = "https://abstract-generator.ai-research.id/"
17
  if "MIRROR_URL" in os.environ:
18
  mirror_url = os.environ["MIRROR_URL"]
 
19
 
20
  MODELS = {
21
- "Indonesian Academic Journal - Indonesian GPT-2 Medium": {
22
- "group": "Indonesian Journal",
23
- "name": "cahya/abstract-generator",
24
- "description": "Abstract Generator using Indonesian GPT-2 Medium.",
25
  "text_generator": None,
26
  "tokenizer": None
27
  },
@@ -85,7 +85,7 @@ def process(text_generator, tokenizer, title: str, keywords: str, text: str,
85
  if repetition_penalty == 0.0:
86
  min_penalty = 1.05
87
  max_penalty = 1.5
88
- repetition_penalty = max(min_penalty + (1.0-temperature) * (max_penalty-min_penalty), 0.8)
89
 
90
  keywords = [keyword.strip() for keyword in keywords.split(",")]
91
  keywords = AbstractDataset.join_keywords(keywords, randomize=False)
@@ -102,15 +102,16 @@ def process(text_generator, tokenizer, title: str, keywords: str, text: str,
102
 
103
  text_generator.eval()
104
  sample_outputs = text_generator.generate(generated,
105
- do_sample=do_sample,
106
- min_length=200,
107
- max_length=max_length,
108
- top_k=top_k,
109
- top_p=top_p,
110
- temperature=temperature,
111
- repetition_penalty=repetition_penalty,
112
- num_return_sequences=1
113
- )
 
114
  result = tokenizer.decode(sample_outputs[0], skip_special_tokens=True)
115
  print(f"result: {result}")
116
  prefix_length = len(title) + len(keywords)
@@ -127,9 +128,9 @@ model_name = f"Model name: [{MODELS[model_type]['name']}](https://huggingface.co
127
  st.markdown(model_name)
128
  if prompt_group_name in ["Indonesian GPT-2", "Indonesian Literature", "Indonesian Journal"]:
129
  session_state = SessionState.get(prompt=None, prompt_box=None, text=None)
130
- ALL_PROMPTS = list(PROMPT_LIST[prompt_group_name].keys())+["Custom"]
131
 
132
- prompt = st.selectbox('Prompt', ALL_PROMPTS, index=len(ALL_PROMPTS)-1)
133
 
134
  # Update prompt
135
  if session_state.prompt is None:
@@ -160,6 +161,12 @@ if prompt_group_name in ["Indonesian GPT-2", "Indonesian Literature", "Indonesia
160
  help="The maximum length of the sequence to be generated."
161
  )
162
 
 
 
 
 
 
 
163
  temperature = st.sidebar.slider(
164
  "Temperature",
165
  value=0.4,
@@ -167,15 +174,14 @@ if prompt_group_name in ["Indonesian GPT-2", "Indonesian Literature", "Indonesia
167
  max_value=2.0
168
  )
169
 
170
- do_sample = st.sidebar.checkbox(
171
- "Use sampling",
172
- value=True
173
- )
174
-
175
  top_k = 30
176
  top_p = 0.95
 
177
 
178
- if do_sample:
 
 
 
179
  top_k = st.sidebar.number_input(
180
  "Top k",
181
  value=top_k,
@@ -187,6 +193,19 @@ if prompt_group_name in ["Indonesian GPT-2", "Indonesian Literature", "Indonesia
187
  help="If set to float < 1, only the most probable tokens with probabilities that add up to top_p or higher "
188
  "are kept for generation."
189
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
190
 
191
  seed = st.sidebar.number_input(
192
  "Random Seed",
@@ -194,22 +213,21 @@ if prompt_group_name in ["Indonesian GPT-2", "Indonesian Literature", "Indonesia
194
  help="The number used to initialize a pseudorandom number generator"
195
  )
196
 
197
- repetition_penalty = 0.0
198
- automatic_repetition_penalty = st.sidebar.checkbox(
199
- "Automatic Repetition Penalty",
200
- value=True
201
- )
202
-
203
- if not automatic_repetition_penalty:
204
- repetition_penalty = st.sidebar.slider(
205
- "Repetition Penalty",
206
- value=1.0,
207
- min_value=1.0,
208
- max_value=2.0
209
  )
 
 
 
 
 
 
 
210
 
211
  for group_name in MODELS:
212
- if MODELS[group_name]["group"] in ["Indonesian GPT-2", "Indonesian Literature", "Indonesian Journal"]:
213
  MODELS[group_name]["text_generator"], MODELS[group_name]["tokenizer"] = \
214
  get_generator(MODELS[group_name]["name"])
215
 
@@ -226,15 +244,15 @@ if prompt_group_name in ["Indonesian GPT-2", "Indonesian Literature", "Indonesia
226
  temperature=temperature, do_sample=do_sample,
227
  top_k=int(top_k), top_p=float(top_p), seed=seed, repetition_penalty=repetition_penalty)
228
  time_end = time.time()
229
- time_diff = time_end-time_start
230
- #result = result[0]["generated_text"]
231
  st.write(result.replace("\n", " \n"))
232
  st.text("Translation")
233
  translation = translate(result, "en", "id")
234
  st.write(translation.replace("\n", " \n"))
235
  # st.write(f"*do_sample: {do_sample}, top_k: {top_k}, top_p: {top_p}, seed: {seed}*")
236
  info = f"""
237
- *Memory: {memory.total/(1024*1024*1024):.2f}GB, used: {memory.percent}%, available: {memory.available/(1024*1024*1024):.2f}GB*
238
  *Text generated in {time_diff:.5} seconds*
239
  """
240
  st.write(info)
 
10
  import os
11
  from abstract_dataset import AbstractDataset
12
 
 
13
  # st.set_page_config(page_title="Indonesian GPT-2")
14
 
15
+ mirror_url = "https://news-generator.ai-research.id/"
16
  if "MIRROR_URL" in os.environ:
17
  mirror_url = os.environ["MIRROR_URL"]
18
+ hf_auth_token = os.getenv("HF_AUTH_TOKEN", False)
19
 
20
  MODELS = {
21
+ "Indonesian Newspaper - Indonesian GPT-2 Medium": {
22
+ "group": "Indonesian Newspaper",
23
+ "name": "ai-research-id/gpt2-medium-newspaper",
24
+ "description": "Newspaper Generator using Indonesian GPT-2 Medium.",
25
  "text_generator": None,
26
  "tokenizer": None
27
  },
 
85
  if repetition_penalty == 0.0:
86
  min_penalty = 1.05
87
  max_penalty = 1.5
88
+ repetition_penalty = max(min_penalty + (1.0 - temperature) * (max_penalty - min_penalty), 0.8)
89
 
90
  keywords = [keyword.strip() for keyword in keywords.split(",")]
91
  keywords = AbstractDataset.join_keywords(keywords, randomize=False)
 
102
 
103
  text_generator.eval()
104
  sample_outputs = text_generator.generate(generated,
105
+ do_sample=do_sample,
106
+ min_length=200,
107
+ max_length=max_length,
108
+ top_k=top_k,
109
+ top_p=top_p,
110
+ temperature=temperature,
111
+ repetition_penalty=repetition_penalty,
112
+ num_return_sequences=1,
113
+ hf_auth_token=hf_auth_token
114
+ )
115
  result = tokenizer.decode(sample_outputs[0], skip_special_tokens=True)
116
  print(f"result: {result}")
117
  prefix_length = len(title) + len(keywords)
 
128
  st.markdown(model_name)
129
  if prompt_group_name in ["Indonesian GPT-2", "Indonesian Literature", "Indonesian Journal"]:
130
  session_state = SessionState.get(prompt=None, prompt_box=None, text=None)
131
+ ALL_PROMPTS = list(PROMPT_LIST[prompt_group_name].keys()) + ["Custom"]
132
 
133
+ prompt = st.selectbox('Prompt', ALL_PROMPTS, index=len(ALL_PROMPTS) - 1)
134
 
135
  # Update prompt
136
  if session_state.prompt is None:
 
161
  help="The maximum length of the sequence to be generated."
162
  )
163
 
164
+ decoding_methods = st.sidebar.radio(
165
+ "Set the decoding methods:",
166
+ key="decoding",
167
+ options=["Beam Search", "Sampling", "Contrastive Search"],
168
+ )
169
+
170
  temperature = st.sidebar.slider(
171
  "Temperature",
172
  value=0.4,
 
174
  max_value=2.0
175
  )
176
 
 
 
 
 
 
177
  top_k = 30
178
  top_p = 0.95
179
+ repetition_penalty = 0.0
180
 
181
+ if decoding_methods == "Beam Search":
182
+ do_sample = False
183
+ elif decoding_methods == "Sampling":
184
+ do_sample = True
185
  top_k = st.sidebar.number_input(
186
  "Top k",
187
  value=top_k,
 
193
  help="If set to float < 1, only the most probable tokens with probabilities that add up to top_p or higher "
194
  "are kept for generation."
195
  )
196
+ else:
197
+ do_sample = False
198
+ repetition_penalty = 1.0
199
+ penalty_alpha = st.sidebar.number_input(
200
+ "Penalty alpha",
201
+ value=0.6,
202
+ help="The penalty alpha for contrastive search."
203
+ )
204
+ top_k = st.sidebar.number_input(
205
+ "Top k",
206
+ value=4,
207
+ help="The number of highest probability vocabulary tokens to keep for top-k-filtering."
208
+ )
209
 
210
  seed = st.sidebar.number_input(
211
  "Random Seed",
 
213
  help="The number used to initialize a pseudorandom number generator"
214
  )
215
 
216
+ if decoding_methods != "Contrastive Search":
217
+ automatic_repetition_penalty = st.sidebar.checkbox(
218
+ "Automatic Repetition Penalty",
219
+ value=True
 
 
 
 
 
 
 
 
220
  )
221
+ if not automatic_repetition_penalty:
222
+ repetition_penalty = st.sidebar.slider(
223
+ "Repetition Penalty",
224
+ value=1.0,
225
+ min_value=1.0,
226
+ max_value=2.0
227
+ )
228
 
229
  for group_name in MODELS:
230
+ if MODELS[group_name]["group"] in ["Indonesian Newspaper"]:
231
  MODELS[group_name]["text_generator"], MODELS[group_name]["tokenizer"] = \
232
  get_generator(MODELS[group_name]["name"])
233
 
 
244
  temperature=temperature, do_sample=do_sample,
245
  top_k=int(top_k), top_p=float(top_p), seed=seed, repetition_penalty=repetition_penalty)
246
  time_end = time.time()
247
+ time_diff = time_end - time_start
248
+ # result = result[0]["generated_text"]
249
  st.write(result.replace("\n", " \n"))
250
  st.text("Translation")
251
  translation = translate(result, "en", "id")
252
  st.write(translation.replace("\n", " \n"))
253
  # st.write(f"*do_sample: {do_sample}, top_k: {top_k}, top_p: {top_p}, seed: {seed}*")
254
  info = f"""
255
+ *Memory: {memory.total / (1024 * 1024 * 1024):.2f}GB, used: {memory.percent}%, available: {memory.available / (1024 * 1024 * 1024):.2f}GB*
256
  *Text generated in {time_diff:.5} seconds*
257
  """
258
  st.write(info)