import openai
from openai import OpenAI
import spacy
import pandas as pd
from collections import defaultdict
import random
import torch
import torch.nn as nn
from transformers import MT5Tokenizer, MT5ForConditionalGeneration
import shutil
import os
import subprocess
import json
from safetensors.torch import load_file
from transformers import AutoTokenizer, AutoModelForCausalLM
from trl import AutoModelForCausalLMWithValueHead
from huggingface_hub import login
import logging
import argparse
lang_map = {
    "English": ("en", "en_core_web_sm"),
    "Russian": ("ru", "ru_core_news_sm"),
    "German": ("de", "de_core_news_sm"),
    "Japanese": ("ja", "ja_core_news_sm"),
    "Korean": ("ko", "ko_core_news_sm"),
    "Spanish": ("es", "es_core_news_sm"),
    "Simplified Chinese": ("zh", "zh_core_web_sm"),
    "Traditional Chinese": ("zh", "zh_core_web_sm")
}
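# The spaCy pipelines referenced above must be installed separately, e.g.:
#   python -m spacy download ja_core_news_sm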
################################# folder / file processing #################################

def clear_folder(folder_path, session_id):
    """Delete every file/dir in folder_path whose name starts with session_id;
    create the folder if it does not exist yet."""
    if not os.path.exists(folder_path):
        os.makedirs(folder_path)
        return
    for filename in os.listdir(folder_path):
        if filename.startswith(session_id):
            file_path = os.path.join(folder_path, filename)
            try:
                if os.path.isfile(file_path) or os.path.islink(file_path):
                    os.remove(file_path)
                elif os.path.isdir(file_path):
                    shutil.rmtree(file_path)
            except Exception as e:
                print(f"Failed to delete {file_path}. Reason: {e}")
def delete_files_with_mt(folder_path):
    """Delete every regular file in folder_path whose name contains 'mt'."""
    if not os.path.exists(folder_path):
        print(f"Folder {folder_path} does not exist.")
        return
    for filename in os.listdir(folder_path):
        if "mt" in filename:
            file_path = os.path.join(folder_path, filename)
            try:
                if os.path.isfile(file_path):
                    os.remove(file_path)
                    print(f"Deleted file: {file_path}")
            except Exception as e:
                print(f"Failed to delete {file_path}. Reason: {e}")
################################# reward model for ranking #################################

class metricx_RewardModel:
    def __init__(self):
        self.device = "cuda:0"
        current_dir = os.path.dirname(os.path.abspath(__file__))
        self.json_path = os.path.join(current_dir, 'json_for_metricx')
        if not os.path.exists(self.json_path):
            os.makedirs(self.json_path)

    def get_entry(self, src, mt):
        # QE mode: no reference translation is provided
        return {"source": src, "hypothesis": mt, "reference": ""}

    def write_jsonl(self, src_list, mts, session_id):
        with open(os.path.join(self.json_path, f"{session_id}_input.jsonl"), 'w', encoding='utf-8') as output_file:
            for src, mt in zip(src_list, mts):
                entry = self.get_entry(src, mt)
                output_file.write(json.dumps(entry, ensure_ascii=False) + '\n')

    def run_command(self, session_id):
        devices_map = {'cuda:0': 0, 'cuda:1': 1, 'cuda:2': 2, 'cuda:3': 3}
        command = [
            "python", "-m", "vecalign.metricx24.predict",
            "--tokenizer", "google/mt5-large",
            "--model_name_or_path", "google/metricx-24-hybrid-large-v2p6",
            "--max_input_length", "1536",
            "--batch_size", "1",
            "--input_file", os.path.join(self.json_path, f"{session_id}_input.jsonl"),
            "--output_file", os.path.join(self.json_path, f"{session_id}_output.jsonl"),
            "--device", f"{devices_map.get(self.device, 0)}",
            "--qe"
        ]
        subprocess.run(command)

    def get_predict(self, session_id):
        scores = []
        with open(os.path.join(self.json_path, f"{session_id}_output.jsonl"), 'r', encoding='utf-8') as new_file:
            for line in new_file:
                entry = json.loads(line)
                score = entry.get('prediction', None)
                scores.append(score)
        clear_folder(self.json_path, session_id)
        return scores

    def reward_fn_batch(self, language, src_list, mts, session_id):
        self.write_jsonl(src_list, mts, session_id)
        self.run_command(session_id)
        scores = self.get_predict(session_id)
        # MetricX-24 predicts an error score in [0, 25] (lower is better);
        # map it to a reward in [0, 1] where higher is better.
        rewards = [1 - (score / 25) for score in scores]
        return rewards

reward_model = metricx_RewardModel()
def batch_rm_find_best_translation(evals, language, session_id):
    """
    evals: list of (src, [translation1, translation2, ...])
    Return the translation with the highest reward in each group that meets the THRESHOLD, along with its score.
    Otherwise, return (None, score), where score is the highest score in that group.
    """
    src_list = []
    mt_list = []
    counts = []
    for src, translations in evals:
        counts.append(len(translations))
        for mt in translations:
            src_list.append(src)
            mt_list.append(mt)
    rewards = reward_model.reward_fn_batch(language, src_list, mt_list, session_id)
    print("rewards: ", rewards)
    best_translations = []
    index = 0
    for (src, translations), count in zip(evals, counts):
        group_rewards = rewards[index: index + count]
        index += count
        if count < 2:
            if translations:
                best_translations.append((translations[0], group_rewards[0]))
            else:
                best_translations.append((None, None))
        else:
            best_index = group_rewards.index(max(group_rewards))
            best_score = group_rewards[best_index]
            if best_score >= THRESHOLD:
                best_translations.append((translations[best_index], best_score))
            else:
                best_translations.append((None, best_score))
    return best_translations
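# Example of the `evals` structure consumed above (hypothetical data):
#   evals = [("source sentence", ["candidate 1", "candidate 2"])]
#   batch_rm_find_best_translation(evals, "English", "sess-0")
#   # -> [("candidate 2", 0.81)], or [(None, 0.55)] if no candidate meets THRESHOLD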
def external_find_best_translation(evals, language, session_id):
    """
    evals: list of (src, [translation1, translation2, ...])
    Return the highest-reward translation in each group along with its score.
    Unlike batch_rm_find_best_translation, no THRESHOLD is applied.
    """
    src_list = []
    mt_list = []
    counts = []
    for src, translations in evals:
        counts.append(len(translations))
        for mt in translations:
            src_list.append(src)
            mt_list.append(mt)
    rewards = reward_model.reward_fn_batch(language, src_list, mt_list, session_id)
    print("rewards: ", rewards)
    best_translations = []
    index = 0
    for (src, translations), count in zip(evals, counts):
        group_rewards = rewards[index: index + count]
        index += count
        if count < 2:
            if translations:
                best_translations.append((translations[0], group_rewards[0]))
            else:
                best_translations.append((None, None))
        else:
            best_index = group_rewards.index(max(group_rewards))
            best_score = group_rewards[best_index]
            best_translations.append((translations[best_index], best_score))
    return best_translations

################################# generating translation #################################
def translate_with_deepinfra(model, tokenizer, device, source_sentence, buffer, good_sent_size, src_language, tgt_language):
    system_prompts = [
        "You are a meticulous translator. Provide a literal, word-for-word translation that preserves the structure and meaning of each individual word.",
        "You are a professional translator. Deliver a clear, formal, and precise translation that faithfully conveys the original meaning.",
        "You are a creative and expressive translator. Render the text in a vivid way, as if narrating a captivating story."
    ]
    # Process the buffer: inline the best buffered translation after its source sentence
    processed_source = source_sentence
    if len(buffer) > 0:
        selected_keys = random.sample(list(buffer.keys()), min(len(buffer), good_sent_size))
        for key_sentence in selected_keys:
            key_sentence = key_sentence.strip()
            if key_sentence and (key_sentence in source_sentence):
                translated_sentence = buffer[key_sentence][0][0]
                if f"\n({translated_sentence})\n" not in processed_source:
                    processed_source = processed_source.replace(
                        key_sentence,
                        f"{key_sentence}\n({translated_sentence})\n"
                    )
    translations = []
    for system_prompt in system_prompts:
        if len(buffer) == 0:
            full_prompt = (
                f"System: {system_prompt}\n\n"
                f"### Translate this from {src_language} to {tgt_language}.\n"
                f"{src_language}:\n{source_sentence}\n\n"
                f"{tgt_language}:\n"
            )
        else:
            context_prompt = (
                f"Below is a specialized, intermediate translation task. The input text is a mix of {src_language} and partial {tgt_language} translations. "
                f"In the text, some {src_language} sentences are already followed by preliminary {tgt_language} translations enclosed in parentheses. "
                f"These provided translations are rough references - they may be incomplete, inconsistent, or not fully aligned with the original meaning.\n\n"
                f"Your task is to produce an improved {tgt_language} translation according to the following guidelines:\n"
                f"1. Refinement: For sections with existing {tgt_language} translations (in parentheses), refine and polish them.\n"
                f"2. Completion: For untranslated sections, translate the {src_language} text naturally.\n"
                f"3. Translation Order: Maintain the original sequence - every source sentence must appear in order with its translation right after it.\n"
                f"4. Consistency: Ensure a uniform tone and style.\n"
                f"5. Output only the final {tgt_language} translation. No extra commentary.\n\n"
                f"Note: This is an intermediate version that may later be merged. Focus on clarity and fidelity.\n\n"
                f"Input Text:\n{processed_source}\n\n"
                f"Assistant:"
            )
            full_prompt = f"System: {system_prompt}\n\n{context_prompt}"
        print("--------------------------------------------------------------------------------")
        print("\n full_prompt \n")
        print(full_prompt)
        print("--------------------------------------------------------------------------------")
        # Tokenize and generate
        inputs = tokenizer(full_prompt, return_tensors="pt").to(device)
        outputs = model.generate(
            **inputs,
            max_new_tokens=2048,
            temperature=0.7,
            top_p=0.9,
            do_sample=True
        )
        translation = tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
        print("--------------------------------------------------------------------------------")
        print("\n rollout translation: \n")
        print(translation)
        print("--------------------------------------------------------------------------------")
        translations.append(translation)
    return translations
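# Each call yields one candidate per system prompt above, so downstream
# ranking always compares len(system_prompts) == 3 rollouts per source text.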
def process_buffer_sentences(source_sentences, buffer):
    """Collect the best buffered translation for each source sentence, in source order."""
    translation_map = {}
    for src_key, trans_list in buffer.items():
        if not trans_list or not isinstance(trans_list, list):
            continue
        # trans_list is sorted by score (descending), so entry 0 is the best tuple
        translation_map[src_key] = trans_list[0]
    translations = []
    for src_sent in source_sentences:
        if src_sent in translation_map and translation_map[src_sent]:
            translations.append(translation_map[src_sent][0])
    return translations
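# The buffer maps each source sentence to a score-sorted list of
# (translation, reward) tuples, e.g. (hypothetical values):
#   buffer["こんにちは。"] = [("Hello.", 0.92), ("Hi there.", 0.88)]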
def final_translate_with_deepinfra(model, tokenizer, device, source_sentence, source_segments, buffer, src_language, tgt_language):
    translations = process_buffer_sentences(source_segments, buffer)
    initial_translation = "\n".join(translations)
    rewrite_prompt = (
        f"System: You are a helpful translator and only output the result.\n\n"
        f"Below is an initial translation of a {src_language} text into {tgt_language}. "
        f"This translation may include omissions, inaccuracies, or awkward phrasing. "
        f"Your task is to produce a refined version that is fluent, accurate, and coherent, "
        f"while faithfully preserving the full meaning of the original {src_language} text.\n\n"
        f"### Instructions:\n"
        f"1. Ensure that every detail in the original {src_language} text is accurately represented.\n"
        f"2. Correct any grammatical errors, unnatural expressions, or inconsistencies.\n"
        f"3. Improve the natural flow so that the translation reads as if written by a native speaker.\n"
        f"4. Do not add, omit, or change any essential details from the source text.\n"
        f"5. Output only the final refined translation without any additional commentary.\n\n"
        f"### Original {src_language} Text:\n{source_sentence}\n\n"
        f"### Initial {tgt_language} Translation:\n{initial_translation}\n\n"
        f"Assistant:"
    )
    inputs = tokenizer(rewrite_prompt, return_tensors="pt").to(device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=2048,
        temperature=0.7,
        top_p=0.9,
        do_sample=True
    )
    refined_translation = tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
    return refined_translation
################################# alignment functions #################################

def save_sentences_to_txt(sentences, filename):
    with open(filename, "w", encoding="utf-8") as file:
        for i, sentence in enumerate(sentences):
            print(sentence, i)
            file.write(sentence + "\n")
def segment_sentences_by_punctuation(text, lang):
    segmented_sentences = []
    for paragraph in text.split('\n'):
        if not paragraph.strip():
            continue
        if lang == src_lang:
            doc = src_nlp(paragraph)
        elif lang == tgt_lang:
            doc = mt_nlp(paragraph)
        else:
            raise ValueError(f"No spaCy pipeline loaded for language code: {lang}")
        for sent in doc.sents:
            segmented_sentences.append(sent.text.strip())
    return segmented_sentences
def generate_overlap_and_embedding(txt_file):
    overlaps_file = txt_file + ".overlaps"
    embed_file = txt_file + ".emb"
    current_dir = os.path.dirname(os.path.abspath(__file__))
    overlap_path = os.path.join(current_dir, "overlap.py")
    subprocess.run(["python", overlap_path, "-i", txt_file, "-o", overlaps_file, "-n", "10"])
    # embed.sh relies on the $LASER environment variable, so expand it via the shell
    embed_command = [
        "$LASER/tasks/embed/embed.sh",
        overlaps_file,
        embed_file,
    ]
    subprocess.run(" ".join(embed_command), shell=True)
    return overlaps_file, embed_file
def run_vecalign(src_txt, tgt_txt, src_embed, tgt_embed):
    current_dir = os.path.dirname(os.path.abspath(__file__))
    vecalign_path = os.path.join(current_dir, "vecalign.py")
    result = subprocess.run(
        [
            "python",
            vecalign_path,
            "--alignment_max_size", "8",
            "--src", src_txt,
            "--tgt", tgt_txt,
            "--src_embed", src_txt + ".overlaps", src_embed,
            "--tgt_embed", tgt_txt + ".overlaps", tgt_embed,
        ],
        stdout=subprocess.PIPE,
        text=True,
    )
    alignments = []
    for line in result.stdout.strip().split("\n"):
        if line:
            src_indices, tgt_indices, _ = line.split(":")
            src_indices = list(map(int, src_indices.strip("[]").split(","))) if src_indices.strip("[]") else []
            tgt_indices = list(map(int, tgt_indices.strip("[]").split(","))) if tgt_indices.strip("[]") else []
            alignments.append((src_indices, tgt_indices))
    return alignments
def compute_alignment_stats(alignment_results):
    costs = []
    zero_cost_count = 0
    for entry in alignment_results:
        try:
            cost = float(entry.split(":")[-1])  # Extract the cost value
            if cost == 0.0:
                zero_cost_count += 1
            else:
                costs.append(cost)
        except ValueError:
            continue  # Ignore invalid entries
    # Compute the average cost, ignoring zero-cost samples
    avg_cost = sum(costs) / len(costs) if costs else 0.0
    zero_cost_ratio = zero_cost_count / len(alignment_results) if alignment_results else 0.0
    return avg_cost, zero_cost_ratio
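# Each vecalign output line has the form "[src indices]:[tgt indices]:cost",
# e.g. "[0, 1]:[0]:0.123456". Zero-cost lines are tallied separately and
# excluded from the average, since they typically mark unaligned segments.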
def run_vecalign_explore(src_txt, tgt_txt, src_embed, tgt_embed):
    """
    Runs vecalign multiple times, exploring the best del_percentile_frac.
    Starts from 0.2 and decreases in 0.005 steps, stopping when the zero-cost ratio increases sharply.
    :param src_txt: Source text file
    :param tgt_txt: Target text file
    :param src_embed: Source embeddings file
    :param tgt_embed: Target embeddings file
    :return: Parsed (src_indices, tgt_indices) alignments for the best del_percentile_frac found
    """
    del_percentile_frac = 0.2  # Starting value
    step_size = 0.005  # Exploration step
    prev_zero_cost_ratio = None
    prev_avg_cost = None
    best_avg_cost = float('inf')
    best_del_percentile_frac = del_percentile_frac
    best_zero_cost_ratio = 0.0
    best_alignments = []
    current_dir = os.path.dirname(os.path.abspath(__file__))
    vecalign_path = os.path.join(current_dir, "vecalign.py")
    while del_percentile_frac > 0:
        result = subprocess.run(
            [
                "python",
                vecalign_path,
                "--alignment_max_size", "8",
                "--del_percentile_frac", str(del_percentile_frac),
                "--src", src_txt,
                "--tgt", tgt_txt,
                "--costs_sample_size", "200000",
                "--search_buffer_size", "20",
                "--src_embed", src_txt + ".overlaps", src_embed,
                "--tgt_embed", tgt_txt + ".overlaps", tgt_embed,
            ],
            stdout=subprocess.PIPE,
            text=True,
        )
        output_lines = result.stdout.strip().split("\n")
        avg_cost, zero_cost_ratio = compute_alignment_stats(output_lines)
        print(f"del_percentile_frac: {del_percentile_frac:.3f} | Avg Cost: {avg_cost:.6f} | Zero-Cost Ratio: {zero_cost_ratio:.6%}")
        if prev_zero_cost_ratio not in (None, 0) and (zero_cost_ratio / prev_zero_cost_ratio) > 1.5:
            print(f"Stopping exploration: Zero-cost ratio increased sharply at {del_percentile_frac:.3f}")
            break
        elif prev_zero_cost_ratio is not None and (
            (zero_cost_ratio - prev_zero_cost_ratio) > 0.15 or
            avg_cost > prev_avg_cost or
            avg_cost < 0.3 or zero_cost_ratio > 0.7
        ):
            print(f"Stopping exploration: Zero-cost ratio increased sharply at {del_percentile_frac:.3f}")
            break
        else:
            if avg_cost < best_avg_cost:
                best_avg_cost = avg_cost
                best_del_percentile_frac = del_percentile_frac
                best_zero_cost_ratio = zero_cost_ratio
                best_alignments = output_lines
        prev_zero_cost_ratio = zero_cost_ratio
        prev_avg_cost = avg_cost
        del_percentile_frac -= step_size
    parsed_alignments = []
    for line in best_alignments:
        if line:
            src_indices, tgt_indices, _ = line.split(":")
            src_indices = list(map(int, src_indices.strip("[]").split(","))) if src_indices.strip("[]") else []
            tgt_indices = list(map(int, tgt_indices.strip("[]").split(","))) if tgt_indices.strip("[]") else []
            parsed_alignments.append((src_indices, tgt_indices))
    print("\nBest Found:")
    print(f"del_percentile_frac: {best_del_percentile_frac:.3f} | Avg Cost: {best_avg_cost:.6f} | Zero-Cost Ratio: {best_zero_cost_ratio:.6%}")
    return parsed_alignments
def standardize_common_alignments(common_alignments_list):
    # Reference alignment for standardization (use the shortest alignment set as baseline)
    reference_alignments = min(common_alignments_list, key=lambda alignments: len(alignments))
    standardized_results = []
    for alignments in common_alignments_list:
        standardized_alignment = []
        mt_idx_map = {tuple(src): mt for src, mt in alignments}
        for src_indices, _ in reference_alignments:
            # If src_indices exist in the current alignment, use them directly
            if tuple(src_indices) in mt_idx_map:
                mt_indices = mt_idx_map[tuple(src_indices)]
            else:
                # If not found, merge the single-sentence src alignments
                mt_indices = []
                for src in src_indices:
                    if (src,) in mt_idx_map:
                        mt_indices.extend(mt_idx_map[(src,)])
                # Ensure indices are unique and sorted after merging
                mt_indices = sorted(set(mt_indices))
            standardized_alignment.append((src_indices, mt_indices))
        standardized_results.append(standardized_alignment)
    return standardized_results
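# Worked example (hypothetical indices): if the reference alignment contains
# ([0, 1], ...) but a candidate only has ([0], [0]) and ([1], [1]), the
# candidate's entry is merged into ([0, 1], [0, 1]).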
def generate_windows(source, translations):
    # Segment sentences
    source_segments = segment_sentences_by_punctuation(source, lang=src_lang)
    current_dir = os.path.dirname(os.path.abspath(__file__))
    temp_folder = os.path.join(current_dir, "temp")
    os.makedirs(temp_folder, exist_ok=True)
    # Generate overlaps and embeddings
    src_txt = os.path.join(current_dir, f"temp/{SESSION_ID}_src.txt")
    mt_txt = os.path.join(current_dir, f"temp/{SESSION_ID}_mt.txt")
    print("\n ----------------- source segmentation --------------------------- ")
    save_sentences_to_txt(source_segments, src_txt)
    print(" ------------------------------------------------------------------- \n")
    _, src_embed = generate_overlap_and_embedding(src_txt)
    mt_segments_list = [segment_sentences_by_punctuation(t, lang=tgt_lang) for t in translations]
    adjusted_mt_list = []
    common_alignments_list = []
    for mt_segments in mt_segments_list:
        print("\n ----------------- translation segmentation --------------------------- ")
        save_sentences_to_txt(mt_segments, mt_txt)
        print(" ------------------------------------------------------------------------ \n")
        _, mt_embed = generate_overlap_and_embedding(mt_txt)
        src_mt_alignments = run_vecalign_explore(src_txt, mt_txt, src_embed, mt_embed)  # or run_vecalign
        common_alignments_list.append(src_mt_alignments.copy())
        delete_files_with_mt(temp_folder)
    common_alignments_list = standardize_common_alignments(common_alignments_list)
    adjusted_src = []
    for mt_index, common_alignments in enumerate(common_alignments_list):
        adjusted_src = []
        adjusted_mt = []
        for src_indices, mt_indices in common_alignments:
            mt_indices = [x for x in mt_indices if x != -1]
            if len(src_indices) == 0:
                continue
            aligned_src = " ".join([source_segments[i] for i in src_indices])
            if len(mt_indices) > 0:
                aligned_mt = " ".join([mt_segments_list[mt_index][i] for i in mt_indices])
            else:
                aligned_mt = ""
            adjusted_src.append(aligned_src)
            adjusted_mt.append(aligned_mt)
        adjusted_mt_list.append(adjusted_mt.copy())
    clear_folder(temp_folder, SESSION_ID)
    return adjusted_src, adjusted_mt_list
################################# main function #################################

def get_lang_and_nlp(language):
    if language not in lang_map:
        raise ValueError(f"Unsupported language: {language}")
    lang_code, model_name = lang_map[language]
    return lang_code, spacy.load(model_name)
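# e.g. get_lang_and_nlp("Japanese") -> ("ja", <spaCy ja_core_news_sm pipeline>)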
def translate_text(text, session_id, model, tokenizer, device,
                   src_language="Japanese",
                   task_language="English",
                   max_iterations_value=3,
                   threshold_value=0.7,
                   good_ref_contexts_num_value=5,
                   reward_model_type='metricx'):
    global SRC_LANGUAGE, TASK_LANGUAGE, max_iterations, stop_memory
    global THRESHOLD, good_ref_contexts_num, src_lang, src_nlp, tgt_lang, mt_nlp
    global reward_model, MEMORY_FOLDER, SESSION_ID
    SESSION_ID = session_id
    print("SESSION_ID: ", SESSION_ID)
    MEMORY_FOLDER = "external_translation_memory"
    SRC_LANGUAGE = src_language
    TASK_LANGUAGE = task_language
    max_iterations = max_iterations_value
    stop_memory = list(range(1, max_iterations))
    THRESHOLD = threshold_value
    good_ref_contexts_num = good_ref_contexts_num_value
    # NOTE: overrides the `device` argument with the first available GPU
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    src_lang, src_nlp = get_lang_and_nlp(SRC_LANGUAGE)
    tgt_lang, mt_nlp = get_lang_and_nlp(TASK_LANGUAGE)
    reward_model = metricx_RewardModel()
    buffer = defaultdict(list)
    source_sentence = text.replace("\n", " ")
    source_segments = segment_sentences_by_punctuation(source_sentence, lang=src_lang)
    final_translations = None
    for iteration in range(max_iterations):
        if iteration in stop_memory:
            final_translations = final_translate_with_deepinfra(model, tokenizer, device, source_sentence, source_segments, buffer, SRC_LANGUAGE, TASK_LANGUAGE)
            if iteration == max_iterations - 1:
                break
        else:
            translations = translate_with_deepinfra(model, tokenizer, device, source_sentence, buffer, good_ref_contexts_num + iteration, SRC_LANGUAGE, TASK_LANGUAGE)
        src_windows, mt_windows_list = generate_windows(source_sentence, translations)
        # Evaluate translations and update the buffer
        src_context_list = list(src_windows)
        candidates_list = []
        for window_index in range(len(src_windows)):
            candidates = [mt_windows[window_index] for mt_windows in mt_windows_list]
            candidates_list.append(candidates)
        best_candidate_results = batch_rm_find_best_translation(list(zip(src_context_list, candidates_list)), TASK_LANGUAGE, SESSION_ID)
        for i, src in enumerate(src_context_list):
            best_tuple = best_candidate_results[i]
            if best_tuple[0] is not None:
                if src not in buffer:
                    buffer[src] = [best_tuple]
                else:
                    buffer[src].append(best_tuple)
                buffer[src].sort(key=lambda x: x[1], reverse=True)
        for src, trans_list in buffer.items():
            print(f"Source '{src}': {[t[0] for t in trans_list]}")
    return final_translations
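# Minimal usage sketch (the checkpoint name is illustrative, not prescribed
# by this module; any causal LM/tokenizer pair should work):
# if __name__ == "__main__":
#     tok = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
#     lm = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B-Instruct").to("cuda:0")
#     result = translate_text("こんにちは。", "sess-0", lm, tok, "cuda:0",
#                             src_language="Japanese", task_language="English")
#     print(result)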