Commit 946f7f8 by KuangDW (parent: dd05f29)

specify local llm

Files changed (2):
  1. app.py +30 -17
  2. vecalign/plan2align.py +31 -134
app.py CHANGED
@@ -9,14 +9,14 @@ from openai import OpenAI
 from vecalign.plan2align import translate_text, external_find_best_translation
 from transformers import AutoTokenizer, AutoModelForCausalLM
 from trl import AutoModelForCausalLMWithValueHead
-from huggingface_hub import login
+from huggingface_hub import login, HfApi, snapshot_download
 import spacy
 import subprocess
 import pkg_resources
 import sys
 
 laser_token = os.environ.get("align_enc")
-laser_path = snapshot_download(repo_id="KuangDW/laser", use_auth_token=hf_token)
+laser_path = snapshot_download(repo_id="KuangDW/laser", use_auth_token=laser_token)
 os.environ["LASER"] = laser_path
 
 def check_and_install(package, required_version):
@@ -54,21 +54,35 @@ except OSError:
     download("zh_core_web_sm")
     subprocess.check_call([sys.executable, "-m", "pip", "install", "numpy==1.24.0", "--force-reinstall"])
 
-# ---------- deepinfra translation ----------
-openai = OpenAI(
-    api_key="",
-    base_url="https://api.deepinfra.com/v1/openai",
+# ---------- translation function ----------
+
+# Initialize device
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+print(f"Using device: {device}")
+# Load models once
+print("Loading models...")
+model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+model = AutoModelForCausalLM.from_pretrained(
+    model_id,
+    device_map="auto",
+    torch_dtype=torch.float16
 )
 
 def generate_translation(system_prompt, prompt):
-    response = openai.chat.completions.create(
-        model="meta-llama/Meta-Llama-3.1-8B-Instruct",
-        messages=[
-            {"role": "system", "content": system_prompt},
-            {"role": "user", "content": prompt}
-        ]
-    )
-    translation = response.choices[0].message.content.strip()
+    messages=[
+        {"role": "system", "content": system_prompt},
+        {"role": "user", "content": prompt}
+    ]
+    inputs = tokenizer.apply_chat_template(messages, return_tensors="pt").to(device)
+    outputs = model.generate(
+        inputs,
+        max_new_tokens=512,
+        temperature=0.7,
+        top_p=0.9,
+        do_sample=True
+    )
+    translation = tokenizer.decode(outputs[0][inputs.shape[1]:], skip_special_tokens=True)
     return translation
 
 def check_token_length(text, max_tokens=1024):
@@ -188,7 +202,7 @@ def mpc_translation(text, src_language, target_language, iterations, session_id)
         best_score = score
     return current_trans, best_score
 
-# ---------- Gradio 主流程函數 ----------
+# ---------- Gradio function ----------
 
 def process_text(text, src_language, target_language, max_iterations_value, threshold_value,
                  good_ref_contexts_num_value, translation_methods, state):
@@ -202,7 +216,6 @@ def process_text(text, src_language, target_language, max_iterations_value, thre
     4. MPC 翻譯
     """
 
-    # 初始化各輸出內容
     orig_output = ""
     plan2align_output = ""
     best_of_n_output = ""
@@ -214,7 +227,7 @@ def process_text(text, src_language, target_language, max_iterations_value, thre
         orig_output = f"{orig}\n\nScore: {best_score:.2f}"
     if "Plan2Align" in translation_methods:
         plan2align_trans, best_score = plan2align_translate_text(
-            text, session_id, src_language, target_language,
+            text, session_id, model, tokenizer, device, src_language, target_language,
            max_iterations_value, threshold_value, good_ref_contexts_num_value, "metricx"
        )
        plan2align_output = f"{plan2align_trans}\n\nScore: {best_score:.2f}"
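
The import fix in the first hunk above matters beyond style: the old code passed an undefined hf_token to snapshot_download (and never imported snapshot_download at all), so the LASER download could only fail, while the new revision imports the helper and passes the laser_token actually read from the align_enc secret. A minimal standalone sketch of the corrected step, assuming align_enc holds a valid Hugging Face access token; note that newer huggingface_hub releases prefer token= over the deprecated use_auth_token= the commit still uses:

    import os
    from huggingface_hub import snapshot_download

    # Assumption: the Space secret `align_enc` holds a Hugging Face access token.
    laser_token = os.environ.get("align_enc")

    # `use_auth_token=laser_token` (as committed) still works but is deprecated;
    # `token=` is the current parameter name.
    laser_path = snapshot_download(repo_id="KuangDW/laser", token=laser_token)
    os.environ["LASER"] = laser_path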
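
The second hunk swaps the DeepInfra OpenAI-compatible endpoint for a locally loaded Meta-Llama-3.1-8B-Instruct. Below is a self-contained sketch of the same pattern, not the committed code verbatim: it assumes access to the gated Llama 3.1 weights, and it adds add_generation_prompt=True and torch.no_grad(), which the commit omits but which are standard companions of chat-template generation:

    import torch
    from transformers import AutoTokenizer, AutoModelForCausalLM

    model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"  # gated repo: requires an authorized HF account
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id, device_map="auto", torch_dtype=torch.float16
    )

    def generate_translation(system_prompt, prompt):
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": prompt},
        ]
        # add_generation_prompt=True appends the assistant header so the model
        # starts a fresh reply instead of continuing the user turn.
        inputs = tokenizer.apply_chat_template(
            messages, add_generation_prompt=True, return_tensors="pt"
        ).to(device)
        with torch.no_grad():
            outputs = model.generate(
                inputs, max_new_tokens=512, temperature=0.7, top_p=0.9, do_sample=True
            )
        # Slice off the prompt tokens so only the newly generated text is decoded.
        return tokenizer.decode(outputs[0][inputs.shape[1]:], skip_special_tokens=True).strip()

    # Illustrative call; the prompts are examples, not from the commit.
    print(generate_translation(
        "You are a professional translator.",
        "Translate to English: 今日はいい天気ですね。",
    ))

With device_map="auto", accelerate may shard the fp16 weights across available GPUs; the explicit .to(device) keeps the input tensor on cuda:0, where the first shard normally lives.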
vecalign/plan2align.py CHANGED
@@ -28,12 +28,6 @@ lang_map = {
     "Chinese": ("zh", "zh_core_web_sm")
 }
 
-openai = OpenAI(
-    api_key="",
-    base_url="https://api.deepinfra.com/v1/openai",
-)
-MODEL_NAME= "google/gemma-2-9b-it" # "meta-llama/Meta-Llama-3.1-8B-Instruct"
-
 ################################# folder / file processing #################################
 
 def clear_folder(folder_path):
@@ -180,7 +174,7 @@ def external_find_best_translation(evals, language, session_id):
 
 ################################# generating translation #################################
 
-def translate_with_deepinfra(source_sentence, buffer, good_sent_size, src_language, tgt_language):
+def translate_with_deepinfra(model, tokenizer, device, source_sentence, buffer, good_sent_size, src_language, tgt_language):
     system_prompts = [
         "You are a meticulous translator. Provide a literal, word-for-word translation that preserves the structure and meaning of each individual word.",
         "You are a professional translator. Deliver a clear, formal, and precise translation that faithfully conveys the original meaning.",
@@ -227,14 +221,19 @@ def translate_with_deepinfra(source_sentence, buffer, good_sent_size, src_langua
 
     translations = []
     for prompt in system_prompts:
-        response = openai.chat.completions.create(
-            model=MODEL_NAME,
-            messages=[
-                {"role": "system", "content": prompt},
-                {"role": "user", "content": context_prompt}
-            ]
+        messages=[
+            {"role": "system", "content": prompt},
+            {"role": "user", "content": context_prompt}
+        ]
+        inputs = tokenizer.apply_chat_template(messages, return_tensors="pt").to(device)
+        outputs = model.generate(
+            inputs,
+            max_new_tokens=512,
+            temperature=0.7,
+            top_p=0.9,
+            do_sample=True
        )
-        translation = response.choices[0].message.content.strip()
+        translation = tokenizer.decode(outputs[0][inputs.shape[1]:], skip_special_tokens=True)
 
        print("--------------------------------------------------------------------------------")
        print("\n rollout translation: \n")
@@ -264,7 +263,7 @@ def process_buffer_sentences(source_sentences, buffer):
         translations.append(translation_map[src_sent][0])
     return translations
 
-def final_translate_with_deepinfra(source_sentence, source_segments, buffer, src_language, tgt_language):
+def final_translate_with_deepinfra(model, tokenizer, device, source_sentence, source_segments, buffer, src_language, tgt_language):
     translations = process_buffer_sentences(source_segments, buffer)
     initial_translation = "\n".join(translations)
 
@@ -286,21 +285,23 @@ def final_translate_with_deepinfra(source_sentence, source_segments, buffer, src
 
     print("rewrite prompt:")
     print(rewrite_prompt)
-
-    rewrite_response = openai.chat.completions.create(
-        model=MODEL_NAME, # Replace with your actual model name
-        messages=[
-            {"role": "system", "content": "You are a helpful translator and only output the result."},
-            {"role": "user", "content": rewrite_prompt}
-        ]
-    )
-    translation = rewrite_response.choices[0].message.content.strip()
+    messages=[
+        {"role": "system", "content": "You are a helpful translator and only output the result."},
+        {"role": "user", "content": rewrite_prompt}
+    ]
+    inputs = tokenizer.apply_chat_template(messages, return_tensors="pt").to(device)
+    outputs = model.generate(
+        inputs,
+        max_new_tokens=512,
+        temperature=0.7,
+        top_p=0.9,
+        do_sample=True
+    )
+    translation = tokenizer.decode(outputs[0][inputs.shape[1]:], skip_special_tokens=True)
     return translation
 
 
 ################################# alignment functions #################################
-
-
 def save_sentences_to_txt(sentences, filename):
     i = 0
     with open(filename, "w", encoding="utf-8") as file:
@@ -558,111 +559,13 @@ def generate_windows(source, translations):
 
 ################################# main function #################################
 
-def saving_memory(buffer, index, iteration, final_translations_record):
-    """
-    Save the buffer, and final_translations_record to the Memory folder.
-    """
-    current_dir = os.path.dirname(os.path.abspath(__file__))
-    memory_folder = os.path.join(current_dir, f"{MEMORY_FOLDER}")
-    os.makedirs(memory_folder, exist_ok=True)
-    buffer_file_path = f"{MEMORY_FOLDER}/buffer_{index}_iter_{iteration}.json"
-    metadata_file_path = f"{MEMORY_FOLDER}/metadata_{index}_iter_{iteration}.json"
-
-    buffer_to_save = {key: list(value) for key, value in buffer.items()}
-    with open(buffer_file_path, "w", encoding="utf-8") as f:
-        json.dump(buffer_to_save, f, ensure_ascii=False, indent=4)
-
-    metadata = {
-        "final_translations_record": final_translations_record
-    }
-    with open(metadata_file_path, "w", encoding="utf-8") as f:
-        json.dump(metadata, f, ensure_ascii=False, indent=4)
-
-    print(f"Buffer saved to {buffer_file_path}")
-    print(f"Metadata saved to {metadata_file_path}")
-
-
-def process_chunk():
-
-    data = pd.read_csv(csv_path)
-    for index, row in data.iterrows():
-        print("::::::::::::::::::::::: index :::::::::::::::::::::::", index, " ::::::::::::::::::::::: index :::::::::::::::::::::::", )
-        buffer = defaultdict(list)
-
-        source_sentence = row[src_lang].replace('\n', ' ')
-        source_segments = segment_sentences_by_punctuation(source_sentence, lang=src_lang)
-
-        for iteration in range(max_iterations):
-            print(f"\nStarting iteration {iteration + 1}/{max_iterations}...\n")
-
-            if iteration in stop_memory:
-                final_translations = final_translate_with_deepinfra(source_sentence, source_segments, buffer, SRC_LANGUAGE, TASK_LANGUAGE)
-                print("Final Translation Method:")
-                print(final_translations)
-                final_translations_record = [final_translations]
-                saving_memory(buffer, index, iteration, final_translations_record)
-
-                if iteration == max_iterations - 1:
-                    break
-            else:
-                translations = translate_with_deepinfra(source_sentence, buffer, good_ref_contexts_num+iteration, SRC_LANGUAGE, TASK_LANGUAGE)
-
-            src_windows, mt_windows_list = generate_windows(source_sentence, translations)
-
-            ####################################### Evaluate translations and update buffer #######################################
-            print("Evaluate translations and update buffer ..............")
-
-            # First, store all sources and candidate translations as lists.
-            src_context_list = list(src_windows)
-            candidates_list = []
-            for window_index in range(len(src_windows)):
-                candidates = [mt_windows[window_index] for mt_windows in mt_windows_list]
-                candidates_list.append(candidates)
-
-            # Batch evaluate all candidate translations, returning the best translation and score for each source.
-            best_candidate_results = batch_rm_find_best_translation(list(zip(src_context_list, candidates_list)), TASK_LANGUAGE)
-
-            print("\n Our best candidate results:")
-            print(best_candidate_results)
-            print(" ------------------------------------------------------------------------ \n")
-
-            print("\n===== Initial buffer state =====")
-            for src, translations in buffer.items():
-                print(f"Source '{src}': {[t[0] for t in translations]}")
-
-            # Update the buffer for each source.
-            for i, src in enumerate(src_context_list):
-                best_tuple = best_candidate_results[i]  # (translation, score)
-                if best_tuple[0] is not None:
-                    # If the source is not yet in the buffer, initialize it.
-                    if src not in buffer:
-                        buffer[src] = [best_tuple]
-                        print(f"[ADD] New Source '{src}' Add Translation: '{best_tuple[0]}', Score: {best_tuple[1]}")
-                    else:
-                        # Directly add the new translation to the buffer.
-                        buffer[src].append(best_tuple)
-                        print(f"[ADD] Source '{src}' Add Translation: '{best_tuple[0]}', Score: {best_tuple[1]}")
-
-                    # Sort by score to place the best translation (highest score) at the top.
-                    buffer[src].sort(key=lambda x: x[1], reverse=True)
-                    print(f"[UPDATE] Source '{src}' Best Translation: '{buffer[src][0][0]}'")
-
-            print("\n===== Final buffer state =====")
-            for src, translations in buffer.items():
-                print(f"Source '{src}': {[t[0] for t in translations]}")
-
-
-        print("Final Translation:")
-        print(final_translations)
-
-
 def get_lang_and_nlp(language):
     if language not in lang_map:
         raise ValueError(f"Unsupported language: {language}")
     lang_code, model_name = lang_map[language]
     return lang_code, spacy.load(model_name)
 
-def translate_text(text, session_id,
+def translate_text(text, session_id, model, tokenizer, device,
     src_language="Japanese",
     task_language="English",
    max_iterations_value=3,
@@ -699,14 +602,12 @@ def translate_text(text, session_id,
     final_translations = None
 
     for iteration in range(max_iterations):
-        # print(f"\nStarting iteration {iteration + 1}/{max_iterations}...\n")
         if iteration in stop_memory:
-            final_translations = final_translate_with_deepinfra(source_sentence, source_segments, buffer, SRC_LANGUAGE, TASK_LANGUAGE)
-            # saving_memory(buffer, 0, iteration, [final_translations])
+            final_translations = final_translate_with_deepinfra(model, tokenizer, device, source_sentence, source_segments, buffer, SRC_LANGUAGE, TASK_LANGUAGE)
            if iteration == max_iterations - 1:
                break
        else:
-            translations = translate_with_deepinfra(source_sentence, buffer, good_ref_contexts_num + iteration, SRC_LANGUAGE, TASK_LANGUAGE)
+            translations = translate_with_deepinfra(model, tokenizer, device, source_sentence, buffer, good_ref_contexts_num + iteration, SRC_LANGUAGE, TASK_LANGUAGE)
 
        src_windows, mt_windows_list = generate_windows(source_sentence, translations)
        # print("Evaluate translations and update buffer ..............")
@@ -741,8 +642,4 @@ def translate_text(text, session_id,
 
     # print("Final Translation:")
     # print(final_translations)
-    return final_translations
-
-
-if __name__ == "__main__":
-    process_chunk()
+    return final_translations
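A design observation on this file: after the change, the same chat-template, generate, and decode block appears three times (generate_translation in app.py, plus translate_with_deepinfra and final_translate_with_deepinfra here, both of which keep their now-misleading "deepinfra" names). A hypothetical chat_generate helper, sketched below with an illustrative name and signature that are not part of the commit, could collapse the duplication:

    import torch

    def chat_generate(model, tokenizer, device, system_prompt, user_prompt,
                      max_new_tokens=512):
        # Shared wrapper for the generation pattern this commit repeats verbatim.
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ]
        inputs = tokenizer.apply_chat_template(
            messages, add_generation_prompt=True, return_tensors="pt"
        ).to(device)
        with torch.no_grad():
            outputs = model.generate(
                inputs,
                max_new_tokens=max_new_tokens,
                temperature=0.7,
                top_p=0.9,
                do_sample=True,
            )
        # Decode only the tokens generated after the prompt.
        return tokenizer.decode(outputs[0][inputs.shape[1]:], skip_special_tokens=True).strip()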
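
With the new translate_text signature, callers must thread the shared model objects through explicitly. A hedged usage sketch, assuming model, tokenizer, and device were created once at startup as in app.py above, and that text and session_id come from the caller's session:

    from vecalign.plan2align import translate_text

    final_translation = translate_text(
        text, session_id, model, tokenizer, device,
        src_language="Japanese",
        task_language="English",
        max_iterations_value=3,
    )
    print(final_translation)

Passing the objects as parameters, rather than importing module-level globals, keeps plan2align.py importable without loading an 8B model, at the cost of a wider signature.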