cahya committed
Commit d081dd3 · 1 Parent(s): 23974da

fixed the tokenizer

Files changed (1): app/app.py (+21 -17)
app/app.py CHANGED
@@ -62,17 +62,8 @@ model_type = st.sidebar.selectbox('Model', (MODELS.keys()))
 @st.cache(suppress_st_warning=True, allow_output_mutation=True)
 def get_generator(model_name: str):
     st.write(f"Loading the GPT2 model {model_name}, please wait...")
-    special_tokens = AbstractDataset.special_tokens
     tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=hf_auth_token)
-    tokenizer.add_special_tokens(special_tokens)
-    config = AutoConfig.from_pretrained(model_name,
-                                        bos_token_id=tokenizer.bos_token_id,
-                                        eos_token_id=tokenizer.eos_token_id,
-                                        sep_token_id=tokenizer.sep_token_id,
-                                        pad_token_id=tokenizer.pad_token_id,
-                                        output_hidden_states=False,
-                                        use_auth_token=hf_auth_token)
-    model = GPT2LMHeadModel.from_pretrained(model_name, config=config, use_auth_token=hf_auth_token)
+    model = GPT2LMHeadModel.from_pretrained(model_name, pad_token_id=tokenizer.eos_token_id, use_auth_token=hf_auth_token)
     model.resize_token_embeddings(len(tokenizer))
     return model, tokenizer
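Note: the old path added AbstractDataset.special_tokens to the tokenizer and rebuilt an AutoConfig by hand; the new path assumes the checkpoint's vocabulary already matches the tokenizer and only overrides pad_token_id, since GPT-2 ships without a pad token. A minimal sketch of the simplified loading outside Streamlit (the model name is a placeholder, not necessarily the one the app selects):

from transformers import AutoTokenizer, GPT2LMHeadModel

model_name = "cahya/gpt2-small-indonesian-522M"  # hypothetical checkpoint name
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Keyword arguments to from_pretrained override single config fields, so this
# reuses the eos token as pad without constructing a full AutoConfig:
model = GPT2LMHeadModel.from_pretrained(model_name, pad_token_id=tokenizer.eos_token_id)
# A no-op when the embedding matrix already matches the tokenizer's vocab size:
model.resize_token_embeddings(len(tokenizer))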
@@ -81,24 +72,35 @@ def get_generator(model_name: str):
 # @st.cache(suppress_st_warning=True, hash_funcs={tokenizers.Tokenizer: id})
 def process(text_generator, tokenizer, title: str, keywords: str, text: str,
             max_length: int = 200, do_sample: bool = True, top_k: int = 50, top_p: float = 0.95,
-            temperature: float = 1.0, max_time: float = 120.0, seed=42, repetition_penalty=1.0):
+            temperature: float = 1.0, max_time: float = 120.0, seed=42, repetition_penalty=1.0,
+            penalty_alpha = 0.6):
     # st.write("Cache miss: process")
     set_seed(seed)
     if repetition_penalty == 0.0:
         min_penalty = 1.05
         max_penalty = 1.5
         repetition_penalty = max(min_penalty + (1.0 - temperature) * (max_penalty - min_penalty), 0.8)
-    print("title:", title)
-    print("keywords:", keywords)
+    # print("title:", title)
+    # print("keywords:", keywords)
     prompt = f"title: {title}\nkeywords: {keywords}\n{text}"
-    print("prompt: ", prompt)
+    # print("prompt: ", prompt)

     generated = torch.tensor(tokenizer.encode(prompt)).unsqueeze(0)
     # device = torch.device("cuda")
     # generated = generated.to(device)

+    print("do_sample:", do_sample)
+    print("penalty_alpha:", penalty_alpha)
+    print("max_length:", max_length)
+    print("top_k:", top_k)
+    print("top_p:", top_p)
+    print("temperature:", temperature)
+    print("max_time:", max_time)
+    print("repetition_penalty:", repetition_penalty)
+
     text_generator.eval()
     sample_outputs = text_generator.generate(generated,
+                                             penalty_alpha=penalty_alpha,
                                              do_sample=do_sample,
                                              min_length=200,
                                              max_length=max_length,
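Note: the forwarded penalty_alpha is what switches model.generate() into contrastive search in recent transformers releases: it activates when penalty_alpha > 0 and top_k > 1 with sampling disabled, and is ignored when None. A sketch of the two decoding modes the app now toggles between (model, tokenizer and generated assumed set up as above):

# Contrastive search branch (do_sample=False, penalty_alpha from the sidebar):
sample_outputs = model.generate(generated,
                                do_sample=False,
                                penalty_alpha=0.6,  # degeneration-penalty weight
                                top_k=4,            # candidates re-ranked per step
                                max_length=200)

# Sampling branch (penalty_alpha=None leaves top-k/top-p sampling untouched):
sample_outputs = model.generate(generated,
                                do_sample=True,
                                penalty_alpha=None,
                                top_k=50,
                                top_p=0.95,
                                max_length=200)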
@@ -109,7 +111,7 @@ def process(text_generator, tokenizer, title: str, keywords: str, text: str,
                                              num_return_sequences=1
                                              )
     result = tokenizer.decode(sample_outputs[0], skip_special_tokens=True)
-    print(f"result: {result}")
+    # print(f"result: {result}")
     prefix_length = len(title) + len(keywords) + len("title: keywords: ") + 2
     result = result[prefix_length:]
     return result
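Note: unchanged here, but the sentinel logic above is easy to misread: a UI repetition_penalty of 0.0 means "derive it from temperature", interpolating between 1.05 at temperature 1.0 and 1.5 at temperature 0.0. A worked version of the mapping:

# The "auto" repetition-penalty mapping used in process():
def auto_repetition_penalty(temperature: float) -> float:
    min_penalty, max_penalty = 1.05, 1.5
    return max(min_penalty + (1.0 - temperature) * (max_penalty - min_penalty), 0.8)

for t in (1.0, 0.5, 0.0):
    print(t, round(auto_repetition_penalty(t), 3))  # 1.05, 1.275, 1.5
# The 0.8 floor only engages for temperatures above roughly 1.56.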
@@ -173,6 +175,7 @@ if prompt_group_name in ["Indonesian Newspaper"]:
     top_k = 30
     top_p = 0.95
     repetition_penalty = 0.0
+    penalty_alpha = None

     if decoding_methods == "Beam Search":
         do_sample = False
@@ -191,7 +194,7 @@ if prompt_group_name in ["Indonesian Newspaper"]:
         )
     else:
         do_sample = False
-        repetition_penalty = 1.0
+        repetition_penalty = 1.1
         penalty_alpha = st.sidebar.number_input(
             "Penalty alpha",
             value=0.6,
@@ -237,11 +240,12 @@ if prompt_group_name in ["Indonesian Newspaper"]:
                          title=session_state.title,
                          keywords=session_state.keywords,
                          text=session_state.text, max_length=int(max_length),
-                         temperature=temperature, do_sample=do_sample,
+                         temperature=temperature, do_sample=do_sample, penalty_alpha=penalty_alpha,
                          top_k=int(top_k), top_p=float(top_p), seed=seed, repetition_penalty=repetition_penalty)
         time_end = time.time()
         time_diff = time_end - time_start
         # result = result[0]["generated_text"]
+        result = result[:result.find("title:")]
         st.write(result.replace("\n", " \n"))
         st.text("Translation")
         translation = translate(result, "en", "id")
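Note: the added truncation assumes the model starts a fresh "title:" block when it runs past the article; str.find returns -1 when the marker is absent, in which case the slice silently drops the final character of the result. A defensive variant (hypothetical, not part of this commit):

cut = result.find("title:")
if cut != -1:  # only truncate when the echoed marker actually appears
    result = result[:cut]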
 