KuangDW committed
Commit 18ff227 · 1 parent: 87d5a16

add chunk function

Files changed (1):
  app.py  (+115, −38)
app.py CHANGED
@@ -69,21 +69,33 @@ model = AutoModelForCausalLM.from_pretrained(
      torch_dtype=torch.float16
  )
  
- # def generate_translation(system_prompt, prompt):
- #     messages=[
- #         {"role": "system", "content": system_prompt},
- #         {"role": "user", "content": prompt}
- #     ]
- #     inputs = tokenizer.apply_chat_template(messages, return_tensors="pt").to(device)
- #     outputs = model.generate(
- #         inputs,
- #         max_new_tokens=512,
- #         temperature=0.7,
- #         top_p=0.9,
- #         do_sample=True
- #     )
- #     translation = tokenizer.decode(outputs[0][inputs.shape[1]:], skip_special_tokens=True)
- #     return translation
+ import spacy
+ lang_map = {
+     "English": ("en", "en_core_web_sm"),
+     "Russian": ("ru", "ru_core_news_sm"),
+     "German": ("de", "de_core_news_sm"),
+     "Japanese": ("ja", "ja_core_news_sm"),
+     "Korean": ("ko", "ko_core_news_sm"),
+     "Spanish": ("es", "es_core_news_sm"),
+     "Simplified Chinese": ("zh", "zh_core_web_sm"),
+     "Traditional Chinese": ("zh", "zh_core_web_sm")
+ }
+ 
+ def get_lang_and_nlp(language):
+     if language not in lang_map:
+         raise ValueError(f"Unsupported language: {language}")
+     lang_code, model_name = lang_map[language]
+     return lang_code, spacy.load(model_name)
+ 
+ def segment_sentences_by_punctuation(text, src_nlp):
+     segmented_sentences = []
+     paragraphs = text.split('\n')
+     for paragraph in paragraphs:
+         if paragraph.strip():
+             doc = src_nlp(paragraph)
+             for sent in doc.sents:
+                 segmented_sentences.append(sent.text.strip())
+     return segmented_sentences
  
  def generate_translation(system_prompt, prompt):
      full_prompt = f"System: {system_prompt}\nUser: {prompt}\nAssistant:"
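A minimal usage sketch of the helpers added above (not part of the commit; it assumes the corresponding spaCy pipelines, e.g. zh_core_web_sm, are installed):

    # Hypothetical usage of get_lang_and_nlp / segment_sentences_by_punctuation.
    src_lang, src_nlp = get_lang_and_nlp("Traditional Chinese")   # ("zh", loaded spaCy pipeline)
    segments = segment_sentences_by_punctuation("第一句。第二句!\n第三句?", src_nlp)
    # segments is roughly ["第一句。", "第二句!", "第三句?"]:
    # one entry per detected sentence, with empty paragraphs skipped.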
@@ -160,6 +172,21 @@ def basic_translate(source_sentence, src_language, tgt_language):
          translations.append(translation)
      return translations
  
+ def summary_translate(src_text, temp_tgt_text, tgt_language, session_id):
+     system_prompts = ["You are a helpful rephraser. You only output the rephrased result."]
+     translations = []
+     for prompt_style in system_prompts:
+         prompt = f"### Rephrase the following in {tgt_language}."
+         prompt += f"\n### Input:\n {temp_tgt_text}"
+         prompt += f"\n### Rephrased:\n"
+         translation = generate_translation(prompt_style, prompt)
+         translations.append(translation)
+ 
+     best, score = evaluate_candidates(src_text, translations, tgt_language, session_id)
+     if translations:
+         return best, score
+     return "", 0
+ 
  def plan2align_translate_text(text, session_id, model, tokenizer, device, src_language, task_language, max_iterations_value, threshold_value, good_ref_contexts_num_value, reward_model_type):
      result = translate_text(
          text = text,
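For reference, a short sketch (not from the commit) of the prompt that summary_translate assembles before handing it to generate_translation, assuming English as the target language; the input placeholder stands in for the stitched chunk translations:

    system_prompt = "You are a helpful rephraser. You only output the rephrased result."
    prompt = (
        "### Rephrase the following in English."
        "\n### Input:\n <concatenated chunk translations>"
        "\n### Rephrased:\n"
    )
    # generate_translation(system_prompt, prompt) should then return only the rephrased text.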
@@ -255,23 +282,67 @@ def process_text(text, src_language, target_language, max_iterations_value, thre
      best_of_n_output = ""
      mpc_output = ""
  
- 
-     if "Original" in translation_methods:
-         orig, best_score = original_translation(text, src_language, target_language, session_id)
-         orig_output = f"{orig}\n\nScore: {best_score:.2f}"
-     if "Plan2Align" in translation_methods:
-         plan2align_trans, best_score = plan2align_translate_text(
-             text, session_id, model, tokenizer, device, src_language, target_language,
-             max_iterations_value, threshold_value, good_ref_contexts_num_value, "metricx"
-         )
-         plan2align_output = f"{plan2align_trans}\n\nScore: {best_score:.2f}"
-     if "Best-of-N" in translation_methods:
-         best_candidate, best_score = best_of_n_translation(text, src_language, target_language, max_iterations_value, session_id)
-         best_of_n_output = f"{best_candidate}\n\nScore: {best_score:.2f}"
-     if "MPC" in translation_methods:
-         mpc_candidate, mpc_score = mpc_translation(text, src_language, target_language,
-                                                    max_iterations_value, session_id)
-         mpc_output = f"{mpc_candidate}\n\nScore: {mpc_score:.2f}"
+     src_lang, src_nlp = get_lang_and_nlp(src_language)
+     source_sentence = text.replace("\n", " ")
+     source_segments = segment_sentences_by_punctuation(source_sentence, src_nlp)
+ 
+     if chunk_size == -1:
+         chunks = [' '.join(source_segments)]
+     else:
+         chunks = [' '.join(source_segments[i:i+chunk_size]) for i in range(0, len(source_segments), chunk_size)]
+ 
+     org_translated_chunks = []
+     p2a_translated_chunks = []
+     bfn_translated_chunks = []
+     mpc_translated_chunks = []
+ 
+     for chunk in chunks:
+         if "Original" in translation_methods:
+             translation, _ = original_translation(chunk, src_language, target_language, session_id)
+             org_translated_chunks.append(translation)
+         if "Plan2Align" in translation_methods:
+             translation, _ = plan2align_translate_text(
+                 chunk, session_id, model, tokenizer, device, src_language, target_language,
+                 max_iterations_value, threshold_value, good_ref_contexts_num_value, "metricx"
+             )
+             p2a_translated_chunks.append(translation)
+         if "Best-of-N" in translation_methods:
+             translation, _ = best_of_n_translation(chunk, src_language, target_language, max_iterations_value, session_id)
+             bfn_translated_chunks.append(translation)
+         if "MPC" in translation_methods:
+             translation, _ = mpc_translation(chunk, src_language, target_language, max_iterations_value, session_id)
+             mpc_translated_chunks.append(translation)
+ 
+     org_combined_translation = ' '.join(org_translated_chunks)
+     p2a_combined_translation = ' '.join(p2a_translated_chunks)
+     bfn_combined_translation = ' '.join(bfn_translated_chunks)
+     mpc_combined_translation = ' '.join(mpc_translated_chunks)
+ 
+     orig, best_score = summary_translate(text, org_combined_translation, target_language, session_id)
+     orig_output = f"{orig}\n\nScore: {best_score:.2f}"
+     plan2align_trans, best_score = summary_translate(text, p2a_combined_translation, target_language, session_id)
+     plan2align_output = f"{plan2align_trans}\n\nScore: {best_score:.2f}"
+     best_candidate, best_score = summary_translate(text, bfn_combined_translation, target_language, session_id)
+     best_of_n_output = f"{best_candidate}\n\nScore: {best_score:.2f}"
+     mpc_candidate, mpc_score = summary_translate(text, mpc_combined_translation, target_language, session_id)
+     mpc_output = f"{mpc_candidate}\n\nScore: {mpc_score:.2f}"
+ 
+     # if "Original" in translation_methods:
+     #     orig, best_score = original_translation(text, src_language, target_language, session_id)
+     #     orig_output = f"{orig}\n\nScore: {best_score:.2f}"
+     # if "Plan2Align" in translation_methods:
+     #     plan2align_trans, best_score = plan2align_translate_text(
+     #         text, session_id, model, tokenizer, device, src_language, target_language,
+     #         max_iterations_value, threshold_value, good_ref_contexts_num_value, "metricx"
+     #     )
+     #     plan2align_output = f"{plan2align_trans}\n\nScore: {best_score:.2f}"
+     # if "Best-of-N" in translation_methods:
+     #     best_candidate, best_score = best_of_n_translation(text, src_language, target_language, max_iterations_value, session_id)
+     #     best_of_n_output = f"{best_candidate}\n\nScore: {best_score:.2f}"
+     # if "MPC" in translation_methods:
+     #     mpc_candidate, mpc_score = mpc_translation(text, src_language, target_language,
+     #                                                max_iterations_value, session_id)
+     #     mpc_output = f"{mpc_candidate}\n\nScore: {mpc_score:.2f}"
  
      return orig_output, plan2align_output, best_of_n_output, mpc_output
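A small self-contained sketch (not in app.py) of the grouping rule used above: sentences are joined chunk_size at a time, and chunk_size == -1 keeps the whole text as a single chunk. The helper name make_chunks is hypothetical, introduced only for illustration:

    def make_chunks(source_segments, chunk_size):
        # Mirrors the chunking logic in process_text above.
        if chunk_size == -1:
            return [' '.join(source_segments)]
        return [' '.join(source_segments[i:i + chunk_size])
                for i in range(0, len(source_segments), chunk_size)]

    sents = ["S1.", "S2.", "S3.", "S4.", "S5."]
    print(make_chunks(sents, 2))   # ['S1. S2.', 'S3. S4.', 'S5.']
    print(make_chunks(sents, -1))  # ['S1. S2. S3. S4. S5.']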
 
@@ -310,6 +381,10 @@ with gr.Blocks(title="Test-Time Machine Translation with Plan2Align") as demo:
              value=["Original", "Plan2Align"],
              label="Translation Methods"
          )
+         chunk_size_input = gr.Number(  # ✅ add chunk function
+             label="Chunk Size (Number of sentences per translation, -1 for all)",
+             value=-1
+         )
          translate_button = gr.Button("Translate")
      with gr.Column(scale=2):
          original_output = gr.Textbox(
@@ -343,6 +418,7 @@ with gr.Blocks(title="Test-Time Machine Translation with Plan2Align") as demo:
              threshold_input,
              good_ref_contexts_num_input,
              translation_methods_input,
+             chunk_size_input,  # ✅ add chunk function
              state
          ],
          outputs=[original_output, plan2align_output, best_of_n_output, mpc_output]
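Since chunk_size_input is wired in between translation_methods_input and state, process_text presumably gains a matching parameter in the same position; its def line is outside the hunks shown here, so the following signature is only a guess, with parameter names inferred from how the values are used in the function body:

    # Hypothetical signature — not shown in this diff.
    def process_text(text, src_language, target_language, max_iterations_value,
                     threshold_value, good_ref_contexts_num_value,
                     translation_methods, chunk_size, session_id):
        ...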
@@ -350,11 +426,11 @@ with gr.Blocks(title="Test-Time Machine Translation with Plan2Align") as demo:
  
      gr.Examples(
          examples=[
-             ["台灣夜市文化豐富多彩,從士林夜市到饒河街夜市,提供各種美食、遊戲和購物體驗,吸引了無數遊客。", "Traditional Chinese", "English", 2, 0.7, 1, ["Original", "Plan2Align"]],
-             ["台北101曾經是世界最高的建築物,它不僅是台灣的地標,也象徵著經濟成就和創新精神。", "Traditional Chinese", "Russian", 2, 0.7, 1, ["Original", "Plan2Align"]],
-             ["阿里山日出和森林鐵路是台灣最著名的自然景觀之一,每年吸引數十萬遊客前來欣賞雲海和壯麗的日出。", "Traditional Chinese", "German", 2, 0.7, 1, ["Original", "Plan2Align"]],
-             ["珍珠奶茶,這款源自台灣的獨特飲品,不僅在台灣本地深受喜愛,更以其獨特的風味和口感,在全球掀起了一股熱潮,成為了一種跨越文化、風靡全球的時尚飲品。", "Traditional Chinese", "Japanese", 3, 0.7, 3, ["Original", "Plan2Align"]],
-             ["原住民文化如同一片深邃的星空,閃爍著無數璀璨的傳統與藝術光芒。他們的歌舞,是與祖靈對話的旋律,是與自然共鳴的節奏,每一個舞步、每一聲吟唱,都承載著古老的傳說與智慧。編織,是他們巧手下的詩篇,一絲一線,交織出生命的紋理,也編織出對土地的熱愛與敬畏。木雕,則是他們與自然對話的雕塑,每一刀、每一鑿,都刻畫著對萬物的觀察與敬意,也雕琢出對祖先的追憶與傳承。", "Traditional Chinese", "Korean", 5, 0.7, 5, ["Original", "Plan2Align"]]
+             ["台灣夜市文化豐富多彩,從士林夜市到饒河街夜市,提供各種美食、遊戲和購物體驗,吸引了無數遊客。", "Traditional Chinese", "English", 2, 0.7, 1, ["Original", "Plan2Align"], -1],
+             ["台北101曾經是世界最高的建築物,它不僅是台灣的地標,也象徵著經濟成就和創新精神。", "Traditional Chinese", "Russian", 2, 0.7, 1, ["Original", "Plan2Align"], -1],
+             ["阿里山日出和森林鐵路是台灣最著名的自然景觀之一,每年吸引數十萬遊客前來欣賞雲海和壯麗的日出。", "Traditional Chinese", "German", 2, 0.7, 1, ["Original", "Plan2Align"], -1],
+             ["珍珠奶茶,這款源自台灣的獨特飲品,不僅在台灣本地深受喜愛,更以其獨特的風味和口感,在全球掀起了一股熱潮,成為了一種跨越文化、風靡全球的時尚飲品。", "Traditional Chinese", "Japanese", 3, 0.7, 3, ["Original", "Plan2Align"], -1],
+             ["原住民文化如同一片深邃的星空,閃爍著無數璀璨的傳統與藝術光芒。他們的歌舞,是與祖靈對話的旋律,是與自然共鳴的節奏,每一個舞步、每一聲吟唱,都承載著古老的傳說與智慧。編織,是他們巧手下的詩篇,一絲一線,交織出生命的紋理,也編織出對土地的熱愛與敬畏。木雕,則是他們與自然對話的雕塑,每一刀、每一鑿,都刻畫著對萬物的觀察與敬意,也雕琢出對祖先的追憶與傳承。", "Traditional Chinese", "Korean", 5, 0.7, 5, ["Original", "Plan2Align"], -1]
          ],
          inputs=[
              source_text,
@@ -363,7 +439,8 @@ with gr.Blocks(title="Test-Time Machine Translation with Plan2Align") as demo:
              max_iterations_input,
              threshold_input,
              good_ref_contexts_num_input,
-             translation_methods_input
+             translation_methods_input,
+             chunk_size_input  # ✅ add chunk function
          ],
          outputs=[original_output, plan2align_output, best_of_n_output, mpc_output],
          fn=process_text
 