cahya committed on
Commit
23974da
·
1 Parent(s): 169acf1
Files changed (1) hide show
  1. app/app.py +6 -11
app/app.py CHANGED
@@ -88,15 +88,10 @@ def process(text_generator, tokenizer, title: str, keywords: str, text: str,
88
  min_penalty = 1.05
89
  max_penalty = 1.5
90
  repetition_penalty = max(min_penalty + (1.0 - temperature) * (max_penalty - min_penalty), 0.8)
91
-
92
- keywords = [keyword.strip() for keyword in keywords.split(",")]
93
- keywords = AbstractDataset.join_keywords(keywords, randomize=False)
94
-
95
- special_tokens = AbstractDataset.special_tokens
96
- prompt = special_tokens['bos_token'] + title + \
97
- special_tokens['sep_token'] + keywords + special_tokens['sep_token'] + text
98
-
99
- print(f"title: {title}, keywords: {keywords}, text: {text}")
100
 
101
  generated = torch.tensor(tokenizer.encode(prompt)).unsqueeze(0)
102
  # device = torch.device("cuda")
@@ -115,7 +110,7 @@ def process(text_generator, tokenizer, title: str, keywords: str, text: str,
115
  )
116
  result = tokenizer.decode(sample_outputs[0], skip_special_tokens=True)
117
  print(f"result: {result}")
118
- prefix_length = len(title) + len(keywords)
119
  result = result[prefix_length:]
120
  return result
121
 
@@ -231,7 +226,7 @@ if prompt_group_name in ["Indonesian Newspaper"]:
231
  if MODELS[group_name]["group"] in ["Indonesian Newspaper"]:
232
  MODELS[group_name]["text_generator"], MODELS[group_name]["tokenizer"] = \
233
  get_generator(MODELS[group_name]["name"])
234
- st.write(f"Generator: {MODELS}'")
235
  if st.button("Run"):
236
  with st.spinner(text="Getting results..."):
237
  memory = psutil.virtual_memory()
 
88
  min_penalty = 1.05
89
  max_penalty = 1.5
90
  repetition_penalty = max(min_penalty + (1.0 - temperature) * (max_penalty - min_penalty), 0.8)
91
+ print("title:", title)
92
+ print("keywords:", keywords)
93
+ prompt = f"title: {title}\nkeywords: {keywords}\n{text}"
94
+ print("prompt: ", prompt)
 
 
 
 
 
95
 
96
  generated = torch.tensor(tokenizer.encode(prompt)).unsqueeze(0)
97
  # device = torch.device("cuda")
 
110
  )
111
  result = tokenizer.decode(sample_outputs[0], skip_special_tokens=True)
112
  print(f"result: {result}")
113
+ prefix_length = len(title) + len(keywords) + len("title: keywords: ") + 2
114
  result = result[prefix_length:]
115
  return result
116
 
 
226
  if MODELS[group_name]["group"] in ["Indonesian Newspaper"]:
227
  MODELS[group_name]["text_generator"], MODELS[group_name]["tokenizer"] = \
228
  get_generator(MODELS[group_name]["name"])
229
+ # st.write(f"Generator: {MODELS}'")
230
  if st.button("Run"):
231
  with st.spinner(text="Getting results..."):
232
  memory = psutil.virtual_memory()