KuangDW committed · Commit 946f7f8 · 1 Parent(s): dd05f29

specify local llm

Files changed:
- app.py (+30 -17)
- vecalign/plan2align.py (+31 -134)
app.py
CHANGED
@@ -9,14 +9,14 @@ from openai import OpenAI
 from vecalign.plan2align import translate_text, external_find_best_translation
 from transformers import AutoTokenizer, AutoModelForCausalLM
 from trl import AutoModelForCausalLMWithValueHead
-from huggingface_hub import login
+from huggingface_hub import login, HfApi, snapshot_download
 import spacy
 import subprocess
 import pkg_resources
 import sys
 
 laser_token = os.environ.get("align_enc")
-laser_path = snapshot_download(repo_id="KuangDW/laser", use_auth_token=
+laser_path = snapshot_download(repo_id="KuangDW/laser", use_auth_token=laser_token)
 os.environ["LASER"] = laser_path
 
 def check_and_install(package, required_version):
@@ -54,21 +54,35 @@ except OSError:
     download("zh_core_web_sm")
 subprocess.check_call([sys.executable, "-m", "pip", "install", "numpy==1.24.0", "--force-reinstall"])
 
-# ----------
-… (deleted lines not recovered)
+# ---------- translation function ----------
+
+# Initialize device
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+print(f"Using device: {device}")
+# Load models once
+print("Loading models...")
+model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
+tokenizer = AutoTokenizer.from_pretrained(model_id)
+model = AutoModelForCausalLM.from_pretrained(
+    model_id,
+    device_map="auto",
+    torch_dtype=torch.float16
 )
 
 def generate_translation(system_prompt, prompt):
-… (deleted lines not recovered)
+    messages=[
+        {"role": "system", "content": system_prompt},
+        {"role": "user", "content": prompt}
+    ]
+    inputs = tokenizer.apply_chat_template(messages, return_tensors="pt").to(device)
+    outputs = model.generate(
+        inputs,
+        max_new_tokens=512,
+        temperature=0.7,
+        top_p=0.9,
+        do_sample=True
+    )
+    translation = tokenizer.decode(outputs[0][inputs.shape[1]:], skip_special_tokens=True)
     return translation
 
 def check_token_length(text, max_tokens=1024):
@@ -188,7 +202,7 @@ def mpc_translation(text, src_language, target_language, iterations, session_id):
             best_score = score
     return current_trans, best_score
 
-# ---------- Gradio
+# ---------- Gradio function ----------
 
 def process_text(text, src_language, target_language, max_iterations_value, threshold_value,
                  good_ref_contexts_num_value, translation_methods, state):
@@ -202,7 +216,6 @@ def process_text(text, src_language, target_language, max_iterations_value, threshold_value,
     4. MPC 翻譯
     """
 
-    # 初始化各輸出內容
     orig_output = ""
     plan2align_output = ""
     best_of_n_output = ""
@@ -214,7 +227,7 @@ def process_text(text, src_language, target_language, max_iterations_value, threshold_value,
         orig_output = f"{orig}\n\nScore: {best_score:.2f}"
     if "Plan2Align" in translation_methods:
         plan2align_trans, best_score = plan2align_translate_text(
-            text, session_id, src_language, target_language,
+            text, session_id, model, tokenizer, device, src_language, target_language,
            max_iterations_value, threshold_value, good_ref_contexts_num_value, "metricx"
         )
         plan2align_output = f"{plan2align_trans}\n\nScore: {best_score:.2f}"
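For reference, the new local-generation path above can be exercised directly once app.py's module-level setup has run. A minimal sketch, not part of the commit: the prompt strings below are illustrative (the system prompt is borrowed from the rewrite step in plan2align.py), and model, tokenizer, and device are the globals loaded above.

# Minimal usage sketch of the committed generate_translation helper.
# Assumes app.py's module-level model/tokenizer/device setup has run.
system_prompt = "You are a helpful translator and only output the result."
prompt = "Translate the following sentence into English: 今日は天気がいいです。"  # illustrative input
print(generate_translation(system_prompt, prompt))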
vecalign/plan2align.py
CHANGED
@@ -28,12 +28,6 @@ lang_map = {
     "Chinese": ("zh", "zh_core_web_sm")
 }
 
-openai = OpenAI(
-    api_key="",
-    base_url="https://api.deepinfra.com/v1/openai",
-)
-MODEL_NAME= "google/gemma-2-9b-it" # "meta-llama/Meta-Llama-3.1-8B-Instruct"
-
 ################################# folder / file processing #################################
 
 def clear_folder(folder_path):
@@ -180,7 +174,7 @@ def external_find_best_translation(evals, language, session_id):
 
 ################################# generating translation #################################
 
-def translate_with_deepinfra(source_sentence, buffer, good_sent_size, src_language, tgt_language):
+def translate_with_deepinfra(model, tokenizer, device, source_sentence, buffer, good_sent_size, src_language, tgt_language):
     system_prompts = [
         "You are a meticulous translator. Provide a literal, word-for-word translation that preserves the structure and meaning of each individual word.",
         "You are a professional translator. Deliver a clear, formal, and precise translation that faithfully conveys the original meaning.",
@@ -227,14 +221,19 @@ def translate_with_deepinfra(source_sentence, buffer, good_sent_size, src_language, tgt_language):
 
     translations = []
     for prompt in system_prompts:
-… (deleted lines not recovered)
+        messages=[
+            {"role": "system", "content": prompt},
+            {"role": "user", "content": context_prompt}
+        ]
+        inputs = tokenizer.apply_chat_template(messages, return_tensors="pt").to(device)
+        outputs = model.generate(
+            inputs,
+            max_new_tokens=512,
+            temperature=0.7,
+            top_p=0.9,
+            do_sample=True
         )
-        translation =
+        translation = tokenizer.decode(outputs[0][inputs.shape[1]:], skip_special_tokens=True)
 
         print("--------------------------------------------------------------------------------")
         print("\n rollout translation: \n")
@@ -264,7 +263,7 @@ def process_buffer_sentences(source_sentences, buffer):
         translations.append(translation_map[src_sent][0])
     return translations
 
-def final_translate_with_deepinfra(source_sentence, source_segments, buffer, src_language, tgt_language):
+def final_translate_with_deepinfra(model, tokenizer, device, source_sentence, source_segments, buffer, src_language, tgt_language):
     translations = process_buffer_sentences(source_segments, buffer)
     initial_translation = "\n".join(translations)
 
@@ -286,21 +285,23 @@ def final_translate_with_deepinfra(source_sentence, source_segments, buffer, src_language, tgt_language):
 
     print("rewrite prompt:")
     print(rewrite_prompt)
-… (deleted lines not recovered)
+    messages=[
+        {"role": "system", "content": "You are a helpful translator and only output the result."},
+        {"role": "user", "content": rewrite_prompt}
+    ]
+    inputs = tokenizer.apply_chat_template(messages, return_tensors="pt").to(device)
+    outputs = model.generate(
+        inputs,
+        max_new_tokens=512,
+        temperature=0.7,
+        top_p=0.9,
+        do_sample=True
+    )
+    translation = tokenizer.decode(outputs[0][inputs.shape[1]:], skip_special_tokens=True)
     return translation
 
 
 ################################# alignment functions #################################
-
-
 def save_sentences_to_txt(sentences, filename):
     i = 0
     with open(filename, "w", encoding="utf-8") as file:
@@ -558,111 +559,13 @@ def generate_windows(source, translations):
 
 ################################# main function #################################
 
-def saving_memory(buffer, index, iteration, final_translations_record):
-    """
-    Save the buffer, and final_translations_record to the Memory folder.
-    """
-    current_dir = os.path.dirname(os.path.abspath(__file__))
-    memory_folder = os.path.join(current_dir, f"{MEMORY_FOLDER}")
-    os.makedirs(memory_folder, exist_ok=True)
-    buffer_file_path = f"{MEMORY_FOLDER}/buffer_{index}_iter_{iteration}.json"
-    metadata_file_path = f"{MEMORY_FOLDER}/metadata_{index}_iter_{iteration}.json"
-
-    buffer_to_save = {key: list(value) for key, value in buffer.items()}
-    with open(buffer_file_path, "w", encoding="utf-8") as f:
-        json.dump(buffer_to_save, f, ensure_ascii=False, indent=4)
-
-    metadata = {
-        "final_translations_record": final_translations_record
-    }
-    with open(metadata_file_path, "w", encoding="utf-8") as f:
-        json.dump(metadata, f, ensure_ascii=False, indent=4)
-
-    print(f"Buffer saved to {buffer_file_path}")
-    print(f"Metadata saved to {metadata_file_path}")
-
-
-def process_chunk():
-
-    data = pd.read_csv(csv_path)
-    for index, row in data.iterrows():
-        print("::::::::::::::::::::::: index :::::::::::::::::::::::", index, " ::::::::::::::::::::::: index :::::::::::::::::::::::", )
-        buffer = defaultdict(list)
-
-        source_sentence = row[src_lang].replace('\n', ' ')
-        source_segments = segment_sentences_by_punctuation(source_sentence, lang=src_lang)
-
-        for iteration in range(max_iterations):
-            print(f"\nStarting iteration {iteration + 1}/{max_iterations}...\n")
-
-            if iteration in stop_memory:
-                final_translations = final_translate_with_deepinfra(source_sentence, source_segments, buffer, SRC_LANGUAGE, TASK_LANGUAGE)
-                print("Final Translation Method:")
-                print(final_translations)
-                final_translations_record = [final_translations]
-                saving_memory(buffer, index, iteration, final_translations_record)
-
-                if iteration == max_iterations - 1:
-                    break
-            else:
-                translations = translate_with_deepinfra(source_sentence, buffer, good_ref_contexts_num+iteration, SRC_LANGUAGE, TASK_LANGUAGE)
-
-            src_windows, mt_windows_list = generate_windows(source_sentence, translations)
-
-            ####################################### Evaluate translations and update buffer #######################################
-            print("Evaluate translations and update buffer ..............")
-
-            # First, store all sources and candidate translations as lists.
-            src_context_list = list(src_windows)
-            candidates_list = []
-            for window_index in range(len(src_windows)):
-                candidates = [mt_windows[window_index] for mt_windows in mt_windows_list]
-                candidates_list.append(candidates)
-
-            # Batch evaluate all candidate translations, returning the best translation and score for each source.
-            best_candidate_results = batch_rm_find_best_translation(list(zip(src_context_list, candidates_list)), TASK_LANGUAGE)
-
-            print("\n Our best candidate results:")
-            print(best_candidate_results)
-            print(" ------------------------------------------------------------------------ \n")
-
-            print("\n===== Initial buffer state =====")
-            for src, translations in buffer.items():
-                print(f"Source '{src}': {[t[0] for t in translations]}")
-
-            # Update the buffer for each source.
-            for i, src in enumerate(src_context_list):
-                best_tuple = best_candidate_results[i]  # (translation, score)
-                if best_tuple[0] is not None:
-                    # If the source is not yet in the buffer, initialize it.
-                    if src not in buffer:
-                        buffer[src] = [best_tuple]
-                        print(f"[ADD] New Source '{src}' Add Translation: '{best_tuple[0]}', Score: {best_tuple[1]}")
-                    else:
-                        # Directly add the new translation to the buffer.
-                        buffer[src].append(best_tuple)
-                        print(f"[ADD] Source '{src}' Add Translation: '{best_tuple[0]}', Score: {best_tuple[1]}")
-
-                    # Sort by score to place the best translation (highest score) at the top.
-                    buffer[src].sort(key=lambda x: x[1], reverse=True)
-                    print(f"[UPDATE] Source '{src}' Best Translation: '{buffer[src][0][0]}'")
-
-            print("\n===== Final buffer state =====")
-            for src, translations in buffer.items():
-                print(f"Source '{src}': {[t[0] for t in translations]}")
-
-
-        print("Final Translation:")
-        print(final_translations)
-
-
 def get_lang_and_nlp(language):
     if language not in lang_map:
         raise ValueError(f"Unsupported language: {language}")
     lang_code, model_name = lang_map[language]
     return lang_code, spacy.load(model_name)
 
-def translate_text(text, session_id,
+def translate_text(text, session_id, model, tokenizer, device,
                    src_language="Japanese",
                    task_language="English",
                    max_iterations_value=3,
@@ -699,14 +602,12 @@ def translate_text(text, session_id,
     final_translations = None
 
     for iteration in range(max_iterations):
-        # print(f"\nStarting iteration {iteration + 1}/{max_iterations}...\n")
         if iteration in stop_memory:
-            final_translations = final_translate_with_deepinfra(source_sentence, source_segments, buffer, SRC_LANGUAGE, TASK_LANGUAGE)
-            # saving_memory(buffer, 0, iteration, [final_translations])
+            final_translations = final_translate_with_deepinfra(model, tokenizer, device, source_sentence, source_segments, buffer, SRC_LANGUAGE, TASK_LANGUAGE)
             if iteration == max_iterations - 1:
                 break
         else:
-            translations = translate_with_deepinfra(source_sentence, buffer, good_ref_contexts_num + iteration, SRC_LANGUAGE, TASK_LANGUAGE)
+            translations = translate_with_deepinfra(model, tokenizer, device, source_sentence, buffer, good_ref_contexts_num + iteration, SRC_LANGUAGE, TASK_LANGUAGE)
 
         src_windows, mt_windows_list = generate_windows(source_sentence, translations)
         # print("Evaluate translations and update buffer ..............")
@@ -741,8 +642,4 @@ def translate_text(text, session_id,
 
     # print("Final Translation:")
     # print(final_translations)
-    return final_translations
-
-
-if __name__ == "__main__":
-    process_chunk()
+    return final_translations
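The updated translate_text entry point now takes the loaded model, tokenizer, and device as explicit arguments, mirroring the call made from process_text in app.py. A minimal sketch of invoking the new signature, assuming the keyword parameters after max_iterations_value keep their defaults; the source text and session identifier below are illustrative, not from the commit.

# Hypothetical invocation of the updated translate_text signature.
result = translate_text(
    "今日は天気がいいです。",   # source text (illustrative)
    "demo-session",              # hypothetical session identifier
    model, tokenizer, device,    # objects loaded once in app.py
    src_language="Japanese",
    task_language="English",
    max_iterations_value=3,
)
print(result)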