import os
import gc
import gradio as gr
import torch
import random
import logging
from huggingface_hub import login, HfApi, snapshot_download
import spacy
import subprocess
import pkg_resources
import sys

login(token=os.environ.get("LA_NAME"))
os.environ["LASER"] = "laser"

def check_and_install(package, required_version):
    """Ensure `package` is installed at exactly `required_version`."""
    try:
        dist = pkg_resources.get_distribution(package)
        installed_version = dist.version
        if installed_version != required_version:
            print(f"[{package}] installed version {installed_version} does not match required {required_version}, reinstalling...")
            subprocess.check_call([sys.executable, "-m", "pip", "install", f"{package}=={required_version}", "--force-reinstall"])
        else:
            print(f"[{package}] already at required version {required_version}")
    except pkg_resources.DistributionNotFound:
        print(f"[{package}] not found, installing {required_version}...")
        subprocess.check_call([sys.executable, "-m", "pip", "install", f"{package}=={required_version}"])

packages = {
    "pip": "24.0",
    "fairseq": "0.12.2",
    "torch": "2.6.0",
    "transformers": "4.51.3"
}

for package, version in packages.items():
    check_and_install(package, version)

# Imported after the version pinning above so the pinned releases are the ones loaded.
from transformers import AutoTokenizer, AutoModelForCausalLM
from vecalign.plan2align import translate_text, external_find_best_translation
from trl import AutoModelForCausalLMWithValueHead

# Download any missing spaCy pipelines (Chinese included).
models = ["en_core_web_sm", "ru_core_news_sm", "de_core_news_sm",
          "ja_core_news_sm", "ko_core_news_sm", "es_core_news_sm",
          "zh_core_web_sm"]
for model in models:
    try:
        spacy.load(model)
    except OSError:
        from spacy.cli import download
        download(model)

subprocess.check_call([sys.executable, "-m", "pip", "install", "numpy==1.24.0", "--force-reinstall"])
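# Note: numpy is force-pinned last because the spaCy installs above can pull in
# a newer numpy release that fairseq 0.12.2 does not work with (assumed
# rationale, kept from the original ordering).
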
# ---------- Model setup ----------
# Initialize device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Load models once
print("Loading models...")
model_id = "google/gemma-2-9b-it"  # alternative: "meta-llama/Meta-Llama-3.1-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.float16
)

lang_map = {
    "English": ("en", "en_core_web_sm"),
    "Russian": ("ru", "ru_core_news_sm"),
    "German": ("de", "de_core_news_sm"),
    "Japanese": ("ja", "ja_core_news_sm"),
    "Korean": ("ko", "ko_core_news_sm"),
    "Spanish": ("es", "es_core_news_sm"),
    "Simplified Chinese": ("zh", "zh_core_web_sm"),
    "Traditional Chinese": ("zh", "zh_core_web_sm")
}

def get_lang_and_nlp(language):
    if language not in lang_map:
        raise ValueError(f"Unsupported language: {language}")
    lang_code, model_name = lang_map[language]
    return lang_code, spacy.load(model_name)

def segment_sentences_by_punctuation(text, src_nlp):
    segmented_sentences = []
    paragraphs = text.split('\n')
    for paragraph in paragraphs:
        if paragraph.strip():
            doc = src_nlp(paragraph)
            for sent in doc.sents:
                segmented_sentences.append(sent.text.strip())
    return segmented_sentences
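# Illustrative example (hypothetical, with the Chinese pipeline loaded as zh_nlp):
#     segment_sentences_by_punctuation("今天天氣很好。我們去公園吧。", zh_nlp)
# is expected to return ["今天天氣很好。", "我們去公園吧。"].
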
def generate_translation(system_prompt, prompt):
    # Plain-text prompt format (system + user turns) rather than the model's chat template.
    full_prompt = f"System: {system_prompt}\nUser: {prompt}\nAssistant:"
    inputs = tokenizer(full_prompt, return_tensors="pt").to(device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=2048,
        temperature=0.7,
        top_p=0.9,
        do_sample=True
    )
    # Decode only the newly generated tokens, skipping the prompt.
    translation = tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
    return translation
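# Hedged alternative (a sketch, not wired into the app): the same generation via
# the tokenizer's chat template. Gemma-2 chat templates do not accept a separate
# "system" role, so the system prompt is folded into the user turn here; the app
# continues to use generate_translation above.
def generate_translation_with_chat_template(system_prompt, prompt):
    messages = [{"role": "user", "content": f"{system_prompt}\n\n{prompt}"}]
    input_ids = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    ).to(device)
    outputs = model.generate(input_ids, max_new_tokens=2048,
                             temperature=0.7, top_p=0.9, do_sample=True)
    return tokenizer.decode(outputs[0][input_ids.shape[1]:], skip_special_tokens=True)
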
def check_token_length(text, max_tokens=1024):
    # Count model tokens rather than characters so the limit matches the name.
    return len(tokenizer.encode(text)) <= max_tokens

import uuid

def get_user_session(state=None):
    # Gradio passes a per-user state dict; fall back to a fresh one if missing or malformed.
    if not isinstance(state, dict):
        state = {}
    if not state.get("session_id"):
        state["session_id"] = uuid.uuid4().hex
    return state["session_id"]
# ---------- Translation Functions ----------

def mpc_initial_translate(source_sentence, src_language, tgt_language):
    # Three system prompts give three stylistically different candidates.
    system_prompts = [
        "You are a meticulous translator. Provide a literal, word-for-word translation that preserves the structure and meaning of each individual word.",
        "You are a professional translator. Deliver a clear, formal, and precise translation that faithfully conveys the original meaning.",
        "You are a creative and expressive translator. Render the text in a vivid and imaginative way, as if narrating a captivating story."
    ]
    translations = []
    for prompt_style in system_prompts:
        prompt = f"### Translate this from {src_language} to {tgt_language} and only output the result."
        prompt += f"\n### {src_language}:\n {source_sentence}"
        prompt += f"\n### {tgt_language}:\n"
        translation = generate_translation(prompt_style, prompt)
        translations.append(translation)
    print("mpc_initial_translate")
    print(translations)
    return translations

def mpc_improved_translate(source_sentence, current_translation, src_language, tgt_language):
    system_prompts = [
        "You are a meticulous translator. Please improve the following translation by ensuring it is a literal and structurally precise version.",
        "You are a professional translator. Please refine the provided translation to be clear, formal, and accurate.",
        "You are a creative translator. Please enhance the translation so that it is vivid, natural, and engaging."
    ]
    translations = []
    for prompt_style in system_prompts:
        prompt = (f"Source ({src_language}): {source_sentence}\n"
                  f"Current Translation ({tgt_language}): {current_translation}\n"
                  f"Please provide an improved translation into {tgt_language} and only output the result:")
        translation = generate_translation(prompt_style, prompt)
        translations.append(translation)
    print("mpc_improved_translate")
    print(translations)
    return translations

def basic_translate(source_sentence, src_language, tgt_language):
    system_prompts = ["You are a helpful translator and only output the result."]
    translations = []
    for prompt_style in system_prompts:
        prompt = f"### Translate this from {src_language} to {tgt_language}."
        prompt += f"\n### {src_language}:\n {source_sentence}"
        prompt += f"\n### {tgt_language}:\n"
        translation = generate_translation(prompt_style, prompt)
        translations.append(translation)
    return translations

def summary_translate(src_text, temp_tgt_text, tgt_language, session_id):
    if len(temp_tgt_text.strip()) == 0:
        return "", 0
    system_prompts = ["You are a helpful rephraser. You only output the rephrased result."]
    translations = []
    for prompt_style in system_prompts:
        prompt = f"### Rephrase the following in {tgt_language}."
        prompt += f"\n### Input:\n {temp_tgt_text}"
        prompt += f"\n### Rephrased:\n"
        translation = generate_translation(prompt_style, prompt)
        translations.append(translation)
    try:
        _, score = evaluate_candidates(src_text, translations, tgt_language, session_id)
    except Exception:
        score = 0
    return translations[0], score

def plan2align_translate_text(text, session_id, model, tokenizer, device, src_language, task_language, max_iterations_value, threshold_value, good_ref_contexts_num_value, reward_model_type):
    result = translate_text(
        text=text,
        model=model,
        tokenizer=tokenizer,
        device=device,
        src_language=src_language,
        task_language=task_language,
        max_iterations_value=max_iterations_value,
        threshold_value=threshold_value,
        good_ref_contexts_num_value=good_ref_contexts_num_value,
        reward_model_type=reward_model_type,
        session_id=session_id
    )
    try:
        _, score = evaluate_candidates(text, [result], task_language, session_id)
    except Exception:
        score = 0
    return result, score

def evaluate_candidates(source, candidates, language, session_id):
    evals = [(source, candidates)]
    best_translations = external_find_best_translation(evals, language, session_id)
    best_candidate, best_score = best_translations[0]
    return best_candidate, best_score
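# Contract inferred from usage (not verified against vecalign itself):
# external_find_best_translation takes [(source, [candidate, ...])], the target
# language, and a session id, and returns a list of (best_candidate, score)
# pairs, one per source.
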
def original_translation(text, src_language, target_language, session_id):
    cand_list = basic_translate(text, src_language, target_language)
    if not cand_list:
        return "", 0
    best, score = evaluate_candidates(text, cand_list, target_language, session_id)
    return best, score

def best_of_n_translation(text, src_language, target_language, n, session_id):
    if not check_token_length(text, 4096):
        return "Warning: Input text too long.", 0
    candidates = []
    for i in range(n):
        cand_list = basic_translate(text, src_language, target_language)
        if cand_list:
            candidates.append(cand_list[0])
    try:
        best, score = evaluate_candidates(text, candidates, target_language, session_id)
        print("best_of_n evaluate_candidates results:")
        print(best, score)
    except Exception:
        print("evaluate_candidates fail")
        return "Warning: evaluation failed.", 0
    return best, score

def mpc_translation(text, src_language, target_language, iterations, session_id):
    if not check_token_length(text, 4096):
        return "Warning: Input text too long.", 0
    current_trans = ""
    best_score = None
    for i in range(iterations):
        if i == 0:
            cand_list = mpc_initial_translate(text, src_language, target_language)
        else:
            cand_list = mpc_improved_translate(text, current_trans, src_language, target_language)
        try:
            best, score = evaluate_candidates(text, cand_list, target_language, session_id)
            print("mpc evaluate_candidates results:")
            print(best, score)
            current_trans = best
            best_score = score
        except Exception:
            print("evaluate_candidates fail")
            current_trans = cand_list[0]
            best_score = 0
    return current_trans, best_score
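# Usage sketch (hypothetical values, not executed): each round re-scores three
# stylistic candidates and feeds the winner back into the next round.
#     best, score = mpc_translation("今天天氣很好。", "Traditional Chinese",
#                                   "English", iterations=3, session_id="demo")
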
# ---------- Gradio function ----------

def process_text(text, src_language, target_language, max_iterations_value, threshold_value,
                 good_ref_contexts_num_value, translation_methods=None, chunk_size=-1, state=None):
    """
    Given the source text and target language, produce up to four translations:
    1. Original translation
    2. Plan2Align translation
    3. Best-of-N translation
    4. MPC translation
    """
    translation_methods = translation_methods or ["Original", "Plan2Align"]
    session_id = get_user_session(state)

    orig_output = ""
    plan2align_output = ""
    best_of_n_output = ""
    mpc_output = ""

    src_lang, src_nlp = get_lang_and_nlp(src_language)
    source_sentence = text.replace("\n", " ")
    source_segments = segment_sentences_by_punctuation(source_sentence, src_nlp)

    if chunk_size == -1:
        # Translate the whole text in one pass.
        if "Original" in translation_methods:
            orig, best_score = original_translation(text, src_language, target_language, session_id)
            orig_output = f"{orig}\n\nScore: {best_score:.2f}"
        if "Plan2Align" in translation_methods:
            plan2align_trans, best_score = plan2align_translate_text(
                text, session_id, model, tokenizer, device, src_language, target_language,
                max_iterations_value, threshold_value, good_ref_contexts_num_value, "metricx"
            )
            plan2align_output = f"{plan2align_trans}\n\nScore: {best_score:.2f}"
        if "Best-of-N" in translation_methods:
            best_candidate, best_score = best_of_n_translation(text, src_language, target_language, max_iterations_value, session_id)
            best_of_n_output = f"{best_candidate}\n\nScore: {best_score:.2f}"
        if "MPC" in translation_methods:
            mpc_candidate, mpc_score = mpc_translation(text, src_language, target_language,
                                                       max_iterations_value, session_id)
            mpc_output = f"{mpc_candidate}\n\nScore: {mpc_score:.2f}"
    else:
        # Translate chunk by chunk, then rephrase the concatenation into fluent output.
        chunks = [' '.join(source_segments[i:i+chunk_size]) for i in range(0, len(source_segments), chunk_size)]
        org_translated_chunks = []
        p2a_translated_chunks = []
        bfn_translated_chunks = []
        mpc_translated_chunks = []
        for chunk in chunks:
            if "Original" in translation_methods:
                translation, _ = original_translation(chunk, src_language, target_language, session_id)
                org_translated_chunks.append(translation)
            if "Plan2Align" in translation_methods:
                translation, _ = plan2align_translate_text(
                    chunk, session_id, model, tokenizer, device, src_language, target_language,
                    max_iterations_value, threshold_value, good_ref_contexts_num_value, "metricx"
                )
                p2a_translated_chunks.append(translation)
            if "Best-of-N" in translation_methods:
                translation, _ = best_of_n_translation(chunk, src_language, target_language, max_iterations_value, session_id)
                bfn_translated_chunks.append(translation)
            if "MPC" in translation_methods:
                translation, _ = mpc_translation(chunk, src_language, target_language, max_iterations_value, session_id)
                mpc_translated_chunks.append(translation)
        org_combined_translation = ' '.join(org_translated_chunks)
        p2a_combined_translation = ' '.join(p2a_translated_chunks)
        bfn_combined_translation = ' '.join(bfn_translated_chunks)
        mpc_combined_translation = ' '.join(mpc_translated_chunks)
        orig, best_score = summary_translate(text, org_combined_translation, target_language, session_id)
        orig_output = f"{orig}\n\nScore: {best_score:.2f}"
        plan2align_trans, best_score = summary_translate(text, p2a_combined_translation, target_language, session_id)
        plan2align_output = f"{plan2align_trans}\n\nScore: {best_score:.2f}"
        best_candidate, best_score = summary_translate(text, bfn_combined_translation, target_language, session_id)
        best_of_n_output = f"{best_candidate}\n\nScore: {best_score:.2f}"
        mpc_candidate, best_score = summary_translate(text, mpc_combined_translation, target_language, session_id)
        mpc_output = f"{mpc_candidate}\n\nScore: {best_score:.2f}"

    return orig_output, plan2align_output, best_of_n_output, mpc_output
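# Minimal sketch of calling process_text directly, bypassing the UI (assumes the
# models above loaded; argument values mirror the example rows below):
#     outputs = process_text("台北是一個美麗的城市。", "Traditional Chinese",
#                            "English", 2, 0.7, 1, ["Original"], -1, {})
#     orig_out, p2a_out, bfn_out, mpc_out = outputs
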
# ---------- Gradio UI ----------
target_languages = ["Traditional Chinese", "Simplified Chinese", "English", "Russian", "German", "Japanese", "Korean"]
src_languages = ["Traditional Chinese", "Simplified Chinese", "English", "Russian", "German", "Japanese", "Korean"]

with gr.Blocks(title="Test-Time Machine Translation with Plan2Align") as demo:
    state = gr.State({})
    gr.Markdown("# Translation Demo: Multiple Translation Methods")
    gr.Markdown("Select the translation methods to run (multiple or all):")
    with gr.Row():
        with gr.Column(scale=1):
            source_text = gr.Textbox(
                label="Source Text",
                placeholder="Enter text...",
                lines=5
            )
            src_language_input = gr.Dropdown(
                choices=src_languages,
                value="Traditional Chinese",
                label="Source Language"
            )
            task_language_input = gr.Dropdown(
                choices=target_languages,
                value="English",
                label="Target Language"
            )
            max_iterations_input = gr.Number(label="Max Iterations", value=6)
            threshold_input = gr.Number(label="Threshold", value=0.7)
            good_ref_contexts_num_input = gr.Number(label="Good Ref Contexts Num", value=5)
            translation_methods_input = gr.CheckboxGroup(
                choices=["Original", "Plan2Align", "Best-of-N", "MPC"],
                value=["Original", "Plan2Align"],
                label="Translation Methods"
            )
            chunk_size_input = gr.Number(  # sentences per chunk; -1 translates the whole text at once
                label="Chunk Size (-1 for all)",
                value=-1
            )
            translate_button = gr.Button("Translate")
        with gr.Column(scale=2):
            original_output = gr.Textbox(
                label="Original Translation",
                lines=5,
                interactive=False
            )
            plan2align_output = gr.Textbox(
                label="Plan2Align Translation",
                lines=5,
                interactive=False
            )
            best_of_n_output = gr.Textbox(
                label="Best-of-N Translation",
                lines=5,
                interactive=False
            )
            mpc_output = gr.Textbox(
                label="MPC Translation",
                lines=5,
                interactive=False
            )
    translate_button.click(
        fn=process_text,
        inputs=[
            source_text,
            src_language_input,
            task_language_input,
            max_iterations_input,
            threshold_input,
            good_ref_contexts_num_input,
            translation_methods_input,
            chunk_size_input,
            state
        ],
        outputs=[original_output, plan2align_output, best_of_n_output, mpc_output]
    )
    gr.Examples(
        examples=[
            ["台灣夜市文化豐富多彩,從士林夜市到饒河街夜市,提供各種美食、遊戲和購物體驗,吸引了無數遊客。", "Traditional Chinese", "English", 2, 0.7, 1, ["Original", "Plan2Align"], -1],
            ["台北101曾經是世界最高的建築物,它不僅是台灣的地標,也象徵著經濟成就和創新精神。", "Traditional Chinese", "Japanese", 2, 0.7, 1, ["Original", "Plan2Align"], -1],
            ["阿里山日出和森林鐵路是台灣最著名的自然景觀之一,每年吸引數十萬遊客前來欣賞雲海和壯麗的日出。", "Traditional Chinese", "Korean", 2, 0.7, 1, ["Original", "Plan2Align"], -1],
            # ["珍珠奶茶,這款源自台灣的獨特飲品,不僅在台灣本地深受喜愛,更以其獨特的風味和口感,在全球掀起了一股熱潮,成為了一種跨越文化、風靡全球的時尚飲品。", "Traditional Chinese", "Japanese", 3, 0.7, 3, ["Original", "Plan2Align"], -1],
            # ["原住民文化如同一片深邃的星空,閃爍著無數璀璨的傳統與藝術光芒。他們的歌舞,是與祖靈對話的旋律,是與自然共鳴的節奏,每一個舞步、每一聲吟唱,都承載著古老的傳說與智慧。編織,是他們巧手下的詩篇,一絲一線,交織出生命的紋理,也編織出對土地的熱愛與敬畏。木雕,則是他們與自然對話的雕塑,每一刀、每一鑿,都刻畫著對萬物的觀察與敬意,也雕琢出對祖先的追憶與傳承。", "Traditional Chinese", "Korean", 5, 0.7, 5, ["Original", "Plan2Align"], -1]
        ],
        inputs=[
            source_text,
            src_language_input,
            task_language_input,
            max_iterations_input,
            threshold_input,
            good_ref_contexts_num_input,
            translation_methods_input,
            chunk_size_input
        ],
        outputs=[original_output, plan2align_output, best_of_n_output, mpc_output],
        fn=process_text
    )
gr.Markdown("## How It Works") | |
gr.Markdown(""" | |
1. **Original Translation:** 利用固定提示生成候選,直接取首個候選作為原始翻譯。 | |
2. **Plan2Align Translation:** 採用 context alignment 和 self-rewriting 策略進行翻譯,適合長文翻譯。 | |
3. **Best-of-N Translation:** 重複生成多次候選,評分選出最佳翻譯,適合短文翻譯。 | |
4. **MPC Translation:** 以迭代改善策略,每輪生成候選後評分,並將最佳翻譯作為下一輪輸入,適合短文翻譯。 | |
若輸入文本超過 1024 tokens,Best-of-N 與 MPC 方法會回傳警告訊息。 | |
""") | |
if __name__ == "__main__":
    demo.launch(share=True)