Spaces:
Sleeping
Sleeping
Fix best-of-N exceeded-token-limit issue
Browse files
app.py
CHANGED
@@ -161,7 +161,10 @@ def plan2align_translate_text(text, session_id, model, tokenizer, device, src_la
|
|
161 |
reward_model_type=reward_model_type,
|
162 |
session_id=session_id
|
163 |
)
|
164 |
-
|
|
|
|
|
|
|
165 |
return result, score
|
166 |
|
167 |
def evaluate_candidates(source, candidates, language, session_id):
|
@@ -178,21 +181,25 @@ def original_translation(text, src_language, target_language, session_id):
|
|
178 |
return "", 0
|
179 |
|
180 |
def best_of_n_translation(text, src_language, target_language, n, session_id):
|
181 |
-
if not check_token_length(text,
|
182 |
-
return "Warning: Input text
|
183 |
candidates = []
|
184 |
for i in range(n):
|
185 |
cand_list = basic_translate(text, src_language, target_language)
|
186 |
if cand_list:
|
187 |
candidates.append(cand_list[0])
|
188 |
-
|
189 |
-
|
190 |
-
|
|
|
|
|
|
|
|
|
191 |
return best, score
|
192 |
|
193 |
def mpc_translation(text, src_language, target_language, iterations, session_id):
|
194 |
-
if not check_token_length(text,
|
195 |
-
return "Warning: Input text
|
196 |
current_trans = ""
|
197 |
best_score = None
|
198 |
for i in range(iterations):
|
@@ -201,11 +208,17 @@ def mpc_translation(text, src_language, target_language, iterations, session_id)
|
|
201 |
else:
|
202 |
cand_list = mpc_improved_translate(text, current_trans, src_language, target_language)
|
203 |
|
204 |
-
|
205 |
-
|
206 |
-
|
207 |
-
|
208 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
209 |
return current_trans, best_score
|
210 |
|
211 |
# ---------- Gradio function ----------
|
@@ -240,8 +253,7 @@ def process_text(text, src_language, target_language, max_iterations_value, thre
|
|
240 |
)
|
241 |
plan2align_output = f"{plan2align_trans}\n\nScore: {best_score:.2f}"
|
242 |
if "Best-of-N" in translation_methods:
|
243 |
-
best_candidate, best_score = best_of_n_translation(text, src_language, target_language,
|
244 |
-
max_iterations_value, session_id)
|
245 |
best_of_n_output = f"{best_candidate}\n\nScore: {best_score:.2f}"
|
246 |
if "MPC" in translation_methods:
|
247 |
mpc_candidate, mpc_score = mpc_translation(text, src_language, target_language,
|
|
|
161 |
reward_model_type=reward_model_type,
|
162 |
session_id=session_id
|
163 |
)
|
164 |
+
try:
|
165 |
+
_, score = evaluate_candidates(text, [result], task_language, session_id)
|
166 |
+
except:
|
167 |
+
score = 0
|
168 |
return result, score
|
169 |
|
170 |
def evaluate_candidates(source, candidates, language, session_id):
|
|
|
181 |
return "", 0
|
182 |
|
183 |
def best_of_n_translation(text, src_language, target_language, n, session_id):
    """Generate ``n`` translations of ``text`` and return the best-scoring one.

    Each round calls ``basic_translate`` and keeps the first candidate it
    returns; the collected candidates are then ranked by
    ``evaluate_candidates``.

    Args:
        text: Source text to translate.
        src_language: Source language, forwarded to ``basic_translate``.
        target_language: Target language, forwarded to ``basic_translate``
            and ``evaluate_candidates``.
        n: Number of translation rounds to run.
        session_id: Session identifier forwarded to ``evaluate_candidates``.

    Returns:
        A ``(translation, score)`` tuple. When the input fails the
        ``check_token_length(text, 4096)`` check, when no candidate could be
        generated, or when evaluation raises, the first element is a warning
        message and the score is ``0``.
    """
    if not check_token_length(text, 4096):
        return "Warning: Input text too long.", 0

    candidates = []
    for _ in range(n):
        cand_list = basic_translate(text, src_language, target_language)
        if cand_list:
            candidates.append(cand_list[0])

    # Nothing to rank: report that explicitly instead of handing an empty
    # list to the evaluator and surfacing an unrelated failure.
    if not candidates:
        print("best_of_n: no candidates generated")
        return "Warning: No translation candidates were generated.", 0

    try:
        best, score = evaluate_candidates(text, candidates, target_language, session_id)
    except Exception as exc:
        # Catch Exception (not a bare except) so KeyboardInterrupt/SystemExit
        # still propagate, and log the real cause for debugging.
        # NOTE(review): the returned message presumes the failure is a
        # token-limit overflow in the evaluator - confirm against
        # evaluate_candidates.
        print(f"evaluate_candidates fail: {exc!r}")
        return "Warning: Input text too long.", 0

    print("best_of_n evaluate_candidates results:")
    print(best, score)
    return best, score
|
199 |
|
200 |
def mpc_translation(text, src_language, target_language, iterations, session_id):
|
201 |
+
if not check_token_length(text, 4096):
|
202 |
+
return "Warning: Input text too long.", 0
|
203 |
current_trans = ""
|
204 |
best_score = None
|
205 |
for i in range(iterations):
|
|
|
208 |
else:
|
209 |
cand_list = mpc_improved_translate(text, current_trans, src_language, target_language)
|
210 |
|
211 |
+
try:
|
212 |
+
best, score = evaluate_candidates(text, cand_list, target_language, session_id)
|
213 |
+
print("mpc evaluate_candidates results:")
|
214 |
+
print(best, score)
|
215 |
+
current_trans = best
|
216 |
+
best_score = score
|
217 |
+
except:
|
218 |
+
print("evaluate_candidates fail")
|
219 |
+
current_trans = cand_list[0]
|
220 |
+
best_score = 0
|
221 |
+
|
222 |
return current_trans, best_score
|
223 |
|
224 |
# ---------- Gradio function ----------
|
|
|
253 |
)
|
254 |
plan2align_output = f"{plan2align_trans}\n\nScore: {best_score:.2f}"
|
255 |
if "Best-of-N" in translation_methods:
|
256 |
+
best_candidate, best_score = best_of_n_translation(text, src_language, target_language, max_iterations_value, session_id)
|
|
|
257 |
best_of_n_output = f"{best_candidate}\n\nScore: {best_score:.2f}"
|
258 |
if "MPC" in translation_methods:
|
259 |
mpc_candidate, mpc_score = mpc_translation(text, src_language, target_language,
|