cahya committed on
Commit a0df823 · 1 Parent(s): 8f794cd

remove print out hf_

Files changed (1)
  1. app/app.py +3 -16
app/app.py CHANGED
@@ -16,7 +16,6 @@ mirror_url = "https://news-generator.ai-research.id/"
 if "MIRROR_URL" in os.environ:
     mirror_url = os.environ["MIRROR_URL"]
 hf_auth_token = os.getenv("HF_AUTH_TOKEN", False)
-st.write(f"Using Hugging Face auth token: {hf_auth_token[:10]}...")

 MODELS = {
     "Indonesian Newspaper - Indonesian GPT-2 Medium": {
@@ -80,24 +79,12 @@ def process(text_generator, tokenizer, title: str, keywords: str, text: str,
     min_penalty = 1.05
     max_penalty = 1.5
     repetition_penalty = max(min_penalty + (1.0 - temperature) * (max_penalty - min_penalty), 0.8)
-    # print("title:", title)
-    # print("keywords:", keywords)
     prompt = f"title: {title}\nkeywords: {keywords}\n{text}"
-    # print("prompt: ", prompt)

     generated = torch.tensor(tokenizer.encode(prompt)).unsqueeze(0)
     # device = torch.device("cuda")
     # generated = generated.to(device)

-    print("do_sample:", do_sample)
-    print("penalty_alpha:", penalty_alpha)
-    print("max_length:", max_length)
-    print("top_k:", top_k)
-    print("top_p:", top_p)
-    print("temperature:", temperature)
-    print("max_time:", max_time)
-    print("repetition_penalty:", repetition_penalty)
-
     text_generator.eval()
     sample_outputs = text_generator.generate(generated,
                                              penalty_alpha=penalty_alpha,
@@ -111,16 +98,17 @@ def process(text_generator, tokenizer, title: str, keywords: str, text: str,
                                              num_return_sequences=1
                                              )
     result = tokenizer.decode(sample_outputs[0], skip_special_tokens=True)
-    # print(f"result: {result}")
     prefix_length = len(title) + len(keywords) + len("title: keywords: ") + 2
     result = result[prefix_length:]
+    title_index = result.find("title: ")
+    result = result[:title_index] if title_index > 0 else result
     return result


 st.title("Indonesian GPT-2 Applications")
 prompt_group_name = MODELS[model_type]["group"]
 st.header(prompt_group_name)
-description = f"This is a bilingual (Indonesian and English) abstract generator using Indonesian GPT-2 Medium. We finetuned it with the Indonesian paper abstract dataset."
+description = f"This is a news generator using Indonesian GPT-2 Medium. We finetuned the pre-trained model with the Indonesian online newspaper dataset."
 st.markdown(description)
 model_name = f"Model name: [{MODELS[model_type]['name']}](https://huggingface.co/{MODELS[model_type]['name']})"
 st.markdown(model_name)
@@ -246,7 +234,6 @@ if prompt_group_name in ["Indonesian Newspaper"]:
         time_end = time.time()
         time_diff = time_end - time_start
         # result = result[0]["generated_text"]
-        result = result[:result.find("title:")]
         st.write(result.replace("\n", " \n"))
         st.text("Translation")
         translation = translate(result, "en", "id")
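A note on the truncation logic this commit moves from the caller (last hunk) into process() (third hunk): str.find returns -1 when the marker is absent, so the old unconditional slice result[:result.find("title:")] would map a missing marker to result[:-1] and silently drop the last character of the generated text. Below is a minimal sketch of the new, guarded behaviour; the helper name truncate_at_title is illustrative and not part of the app.

def truncate_at_title(result: str) -> str:
    # Same logic as the lines added to process(): only cut when a later
    # "title: " marker is actually found (index > 0); otherwise return
    # the generated text unchanged.
    title_index = result.find("title: ")
    return result[:title_index] if title_index > 0 else result

print(truncate_at_title("Jakarta - Berita hari ini.\ntitle: Berita lain"))  # keeps only the first article
print(truncate_at_title("Jakarta - Berita hari ini."))  # unchanged; no trailing character lost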