KuangDW committed on
Commit
5e39340
·
1 Parent(s): ba3dd45

Fix best-of-N exceeding token limit issue

Browse files
Files changed (1) hide show
  1. app.py +27 -15
app.py CHANGED
@@ -161,7 +161,10 @@ def plan2align_translate_text(text, session_id, model, tokenizer, device, src_la
161
  reward_model_type=reward_model_type,
162
  session_id=session_id
163
  )
164
- _, score = evaluate_candidates(text, [result], task_language, session_id)
 
 
 
165
  return result, score
166
 
167
  def evaluate_candidates(source, candidates, language, session_id):
@@ -178,21 +181,25 @@ def original_translation(text, src_language, target_language, session_id):
178
  return "", 0
179
 
180
  def best_of_n_translation(text, src_language, target_language, n, session_id):
181
- if not check_token_length(text, 2048):
182
- return "Warning: Input text exceeds 2048 tokens.", None, ""
183
  candidates = []
184
  for i in range(n):
185
  cand_list = basic_translate(text, src_language, target_language)
186
  if cand_list:
187
  candidates.append(cand_list[0])
188
- best, score = evaluate_candidates(text, candidates, target_language, session_id)
189
- print("best_of_n evaluate_candidates results:")
190
- print(best, score)
 
 
 
 
191
  return best, score
192
 
193
  def mpc_translation(text, src_language, target_language, iterations, session_id):
194
- if not check_token_length(text, 2048):
195
- return "Warning: Input text exceeds 2048 tokens.", None, ""
196
  current_trans = ""
197
  best_score = None
198
  for i in range(iterations):
@@ -201,11 +208,17 @@ def mpc_translation(text, src_language, target_language, iterations, session_id)
201
  else:
202
  cand_list = mpc_improved_translate(text, current_trans, src_language, target_language)
203
 
204
- best, score = evaluate_candidates(text, cand_list, target_language, session_id)
205
- print("mpc evaluate_candidates results:")
206
- print(best, score)
207
- current_trans = best
208
- best_score = score
 
 
 
 
 
 
209
  return current_trans, best_score
210
 
211
  # ---------- Gradio function ----------
@@ -240,8 +253,7 @@ def process_text(text, src_language, target_language, max_iterations_value, thre
240
  )
241
  plan2align_output = f"{plan2align_trans}\n\nScore: {best_score:.2f}"
242
  if "Best-of-N" in translation_methods:
243
- best_candidate, best_score = best_of_n_translation(text, src_language, target_language,
244
- max_iterations_value, session_id)
245
  best_of_n_output = f"{best_candidate}\n\nScore: {best_score:.2f}"
246
  if "MPC" in translation_methods:
247
  mpc_candidate, mpc_score = mpc_translation(text, src_language, target_language,
 
161
  reward_model_type=reward_model_type,
162
  session_id=session_id
163
  )
164
+ try:
165
+ _, score = evaluate_candidates(text, [result], task_language, session_id)
166
+ except:
167
+ score = 0
168
  return result, score
169
 
170
  def evaluate_candidates(source, candidates, language, session_id):
 
181
  return "", 0
182
 
183
  def best_of_n_translation(text, src_language, target_language, n, session_id):
184
+ if not check_token_length(text, 4096):
185
+ return "Warning: Input text too long.", 0
186
  candidates = []
187
  for i in range(n):
188
  cand_list = basic_translate(text, src_language, target_language)
189
  if cand_list:
190
  candidates.append(cand_list[0])
191
+ try:
192
+ best, score = evaluate_candidates(text, candidates, target_language, session_id)
193
+ print("best_of_n evaluate_candidates results:")
194
+ print(best, score)
195
+ except:
196
+ print("evaluate_candidates fail")
197
+ return "Warning: Input text too long.", 0
198
  return best, score
199
 
200
  def mpc_translation(text, src_language, target_language, iterations, session_id):
201
+ if not check_token_length(text, 4096):
202
+ return "Warning: Input text too long.", 0
203
  current_trans = ""
204
  best_score = None
205
  for i in range(iterations):
 
208
  else:
209
  cand_list = mpc_improved_translate(text, current_trans, src_language, target_language)
210
 
211
+ try:
212
+ best, score = evaluate_candidates(text, cand_list, target_language, session_id)
213
+ print("mpc evaluate_candidates results:")
214
+ print(best, score)
215
+ current_trans = best
216
+ best_score = score
217
+ except:
218
+ print("evaluate_candidates fail")
219
+ current_trans = cand_list[0]
220
+ best_score = 0
221
+
222
  return current_trans, best_score
223
 
224
  # ---------- Gradio function ----------
 
253
  )
254
  plan2align_output = f"{plan2align_trans}\n\nScore: {best_score:.2f}"
255
  if "Best-of-N" in translation_methods:
256
+ best_candidate, best_score = best_of_n_translation(text, src_language, target_language, max_iterations_value, session_id)
 
257
  best_of_n_output = f"{best_candidate}\n\nScore: {best_score:.2f}"
258
  if "MPC" in translation_methods:
259
  mpc_candidate, mpc_score = mpc_translation(text, src_language, target_language,