# P2A-test-NV / app.py
import os
import gc
import gradio as gr
import torch
import random
import logging
from huggingface_hub import login, HfApi, snapshot_download
import spacy
import subprocess
import pkg_resources
import sys
login(token=os.environ.get("LA_NAME"))
os.environ["LASER"] = "laser"
def check_and_install(package, required_version):
    try:
        dist = pkg_resources.get_distribution(package)
        installed_version = dist.version
        if installed_version != required_version:
            print(f"[{package}] installed version {installed_version} differs from required {required_version}; reinstalling...")
            subprocess.check_call([sys.executable, "-m", "pip", "install",
                                   f"{package}=={required_version}", "--force-reinstall"])
        else:
            print(f"[{package}] already at required version {required_version}")
    except pkg_resources.DistributionNotFound:
        print(f"[{package}] not found, installing {required_version}...")
        subprocess.check_call([sys.executable, "-m", "pip", "install",
                               f"{package}=={required_version}"])
packages = {
"pip": "24.0",
"fairseq": "0.12.2",
"torch": "2.6.0",
"transformers": "4.51.3"
}
for package, version in packages.items():
check_and_install(package, version)
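# Note: the version pins above must run before the imports below, so that
# fairseq / transformers / trl are imported at the pinned versions.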
from transformers import AutoTokenizer, AutoModelForCausalLM
from vecalign.plan2align import translate_text, external_find_best_translation
from trl import AutoModelForCausalLMWithValueHead
# Download any spaCy sentence-segmentation models that are missing.
models = ["en_core_web_sm", "ru_core_news_sm", "de_core_news_sm",
          "ja_core_news_sm", "ko_core_news_sm", "es_core_news_sm",
          "zh_core_web_sm"]
for model in models:
    try:
        spacy.load(model)
    except OSError:
        from spacy.cli import download
        download(model)
subprocess.check_call([sys.executable, "-m", "pip", "install", "numpy==1.24.0", "--force-reinstall"])
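# Assumption: the installs above can pull in a newer numpy that is
# incompatible with fairseq / vecalign, so numpy is pinned back to 1.24.0.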
# ---------- model and device setup ----------
# Initialize device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# Load models once
print("Loading models...")
model_id = "google/gemma-2-9b-it" # "meta-llama/Meta-Llama-3.1-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
model_id,
device_map="auto",
torch_dtype=torch.float16
)
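# With device_map="auto", accelerate places the model automatically; on a
# single-GPU Space this coincides with `device` (cuda:0), which is where
# generate_translation() moves its inputs (assumption about the runtime).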
import spacy
lang_map = {
"English": ("en", "en_core_web_sm"),
"Russian": ("ru", "ru_core_news_sm"),
"German": ("de", "de_core_news_sm"),
"Japanese": ("ja", "ja_core_news_sm"),
"Korean": ("ko", "ko_core_news_sm"),
"Spanish": ("es", "es_core_news_sm"),
"Simplified Chinese": ("zh", "zh_core_web_sm"),
"Traditional Chinese": ("zh", "zh_core_web_sm")
}
def get_lang_and_nlp(language):
if language not in lang_map:
raise ValueError(f"Unsupported language: {language}")
lang_code, model_name = lang_map[language]
return lang_code, spacy.load(model_name)
def segment_sentences_by_punctuation(text, src_nlp):
segmented_sentences = []
paragraphs = text.split('\n')
for paragraph in paragraphs:
if paragraph.strip():
doc = src_nlp(paragraph)
for sent in doc.sents:
segmented_sentences.append(sent.text.strip())
return segmented_sentences
def generate_translation(system_prompt, prompt):
full_prompt = f"System: {system_prompt}\nUser: {prompt}\nAssistant:"
inputs = tokenizer(full_prompt, return_tensors="pt").to(device)
outputs = model.generate(
**inputs,
max_new_tokens=2048,
temperature=0.7,
top_p=0.9,
do_sample=True
)
translation = tokenizer.decode(outputs[0][inputs['input_ids'].shape[1]:], skip_special_tokens=True)
return translation
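# Sketch (not used above): the gemma-2-it chat template has no separate
# system role, so a template-based variant would fold the system prompt into
# the user turn. Hypothetical alternative:
#
#   messages = [{"role": "user", "content": f"{system_prompt}\n\n{prompt}"}]
#   input_ids = tokenizer.apply_chat_template(
#       messages, add_generation_prompt=True, return_tensors="pt"
#   ).to(device)
#   outputs = model.generate(input_ids, max_new_tokens=2048, do_sample=True)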
def check_token_length(text, max_tokens=1024):
    # Count model tokens rather than characters, matching the name.
    return len(tokenizer.encode(text)) <= max_tokens
import uuid
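# Each browser session keeps its own dict in gr.State; stamping it with a
# uuid gives every user a stable session_id (assumed to be used downstream
# by vecalign/plan2align to keep per-user intermediate files separate).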
def get_user_session(state=None):
if state is None:
state = {}
if not isinstance(state, dict):
state = {}
if not state.get("session_id"):
state["session_id"] = uuid.uuid4().hex
return state["session_id"]
# ---------- Translation Function ----------
def mpc_initial_translate(source_sentence, src_language, tgt_language):
system_prompts = [
"You are a meticulous translator. Provide a literal, word-for-word translation that preserves the structure and meaning of each individual word.",
"You are a professional translator. Deliver a clear, formal, and precise translation that faithfully conveys the original meaning.",
"You are a creative and expressive translator. Render the text in a vivid and imaginative way, as if narrating a captivating story."
]
translations = []
for prompt_style in system_prompts:
prompt = f"### Translate this from {src_language} to {tgt_language} and only output the result."
prompt += f"\n### {src_language}:\n {source_sentence}"
prompt += f"\n### {tgt_language}:\n"
translation = generate_translation(prompt_style, prompt)
translations.append(translation)
print("mpc_initial_translate")
print(translations)
return translations
def mpc_improved_translate(source_sentence, current_translation, src_language, tgt_language):
system_prompts = [
"You are a meticulous translator. Please improve the following translation by ensuring it is a literal and structurally precise version.",
"You are a professional translator. Please refine the provided translation to be clear, formal, and accurate.",
"You are a creative translator. Please enhance the translation so that it is vivid, natural, and engaging."
]
translations = []
for prompt_style in system_prompts:
prompt = (f"Source ({src_language}): {source_sentence}\n"
f"Current Translation ({tgt_language}): {current_translation}\n"
f"Please provide an improved translation into {tgt_language} and only output the result:")
translation = generate_translation(prompt_style, prompt)
translations.append(translation)
print("mpc_improved_translate")
print(translations)
return translations
def basic_translate(source_sentence, src_language, tgt_language):
system_prompts = ["You are a helpful translator and only output the result."]
translations = []
for prompt_style in system_prompts:
prompt = f"### Translate this from {src_language} to {tgt_language}."
prompt += f"\n### {src_language}:\n {source_sentence}"
prompt += f"\n### {tgt_language}:\n"
translation = generate_translation(prompt_style, prompt)
translations.append(translation)
return translations
def summary_translate(src_text, temp_tgt_text, tgt_language, session_id):
if len(temp_tgt_text.strip()) == 0:
return "", 0
system_prompts = ["You are a helpful rephraser. You only output the rephrased result."]
translations = []
for prompt_style in system_prompts:
prompt = f"### Rephrase the following in {tgt_language}."
prompt += f"\n### Input:\n {temp_tgt_text}"
prompt += f"\n### Rephrased:\n"
translation = generate_translation(prompt_style, prompt)
translations.append(translation)
    try:
        _, score = evaluate_candidates(src_text, translations, tgt_language, session_id)
    except Exception:
        score = 0
    return translations[0], score
def plan2align_translate_text(text, session_id, model, tokenizer, device, src_language, task_language, max_iterations_value, threshold_value, good_ref_contexts_num_value, reward_model_type):
result = translate_text(
text = text,
model = model,
tokenizer = tokenizer,
device = device,
src_language=src_language,
task_language=task_language,
max_iterations_value=max_iterations_value,
threshold_value=threshold_value,
good_ref_contexts_num_value=good_ref_contexts_num_value,
reward_model_type=reward_model_type,
session_id=session_id
)
    try:
        _, score = evaluate_candidates(text, [result], task_language, session_id)
    except Exception:
        score = 0
    return result, score
def evaluate_candidates(source, candidates, language, session_id):
evals = [(source, candidates)]
best_translations = external_find_best_translation(evals, language, session_id)
best_candidate, best_score = best_translations[0]
return best_candidate, best_score
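# Usage sketch with hypothetical values:
#   best, score = evaluate_candidates(
#       "台北是台灣的首都。",
#       ["Taipei is the capital of Taiwan.", "Taipei is Taiwan's capital."],
#       "English", session_id)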
def original_translation(text, src_language, target_language, session_id):
cand_list = basic_translate(text, src_language, target_language)
best, score = evaluate_candidates(text, cand_list, target_language, session_id)
if cand_list:
return best, score
return "", 0
def best_of_n_translation(text, src_language, target_language, n, session_id):
if not check_token_length(text, 4096):
return "Warning: Input text too long.", 0
candidates = []
for i in range(n):
cand_list = basic_translate(text, src_language, target_language)
if cand_list:
candidates.append(cand_list[0])
    try:
        best, score = evaluate_candidates(text, candidates, target_language, session_id)
        print("best_of_n evaluate_candidates results:")
        print(best, score)
    except Exception:
        print("evaluate_candidates fail")
        return "Warning: candidate evaluation failed.", 0
    return best, score
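# MPC-style refinement: iteration 0 produces initial candidates from scratch;
# every later iteration conditions new candidates on the current best
# translation and keeps the highest-scoring one.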
def mpc_translation(text, src_language, target_language, iterations, session_id):
if not check_token_length(text, 4096):
return "Warning: Input text too long.", 0
current_trans = ""
best_score = None
for i in range(iterations):
if i == 0:
cand_list = mpc_initial_translate(text, src_language, target_language)
else:
cand_list = mpc_improved_translate(text, current_trans, src_language, target_language)
        try:
            best, score = evaluate_candidates(text, cand_list, target_language, session_id)
            print("mpc evaluate_candidates results:")
            print(best, score)
            current_trans = best
            best_score = score
        except Exception:
            print("evaluate_candidates fail")
            current_trans = cand_list[0]
            best_score = 0
return current_trans, best_score
# ---------- Gradio function ----------
def process_text(text, src_language, target_language, max_iterations_value, threshold_value,
                 good_ref_contexts_num_value, translation_methods=None, chunk_size=-1, state=None):
    """
    Given the source text and a target language, produce up to four translations:
    1. Original translation
    2. Plan2Align translation
    3. Best-of-N translation
    4. MPC translation
    """
    translation_methods = translation_methods or ["Original", "Plan2Align"]
    session_id = get_user_session(state)
orig_output = ""
plan2align_output = ""
best_of_n_output = ""
mpc_output = ""
src_lang, src_nlp = get_lang_and_nlp(src_language)
source_sentence = text.replace("\n", " ")
source_segments = segment_sentences_by_punctuation(source_sentence, src_nlp)
if chunk_size == -1:
if "Original" in translation_methods:
orig, best_score = original_translation(text, src_language, target_language, session_id)
orig_output = f"{orig}\n\nScore: {best_score:.2f}"
if "Plan2Align" in translation_methods:
plan2align_trans, best_score = plan2align_translate_text(
text, session_id, model, tokenizer, device, src_language, target_language,
max_iterations_value, threshold_value, good_ref_contexts_num_value, "metricx"
)
plan2align_output = f"{plan2align_trans}\n\nScore: {best_score:.2f}"
if "Best-of-N" in translation_methods:
best_candidate, best_score = best_of_n_translation(text, src_language, target_language, max_iterations_value, session_id)
best_of_n_output = f"{best_candidate}\n\nScore: {best_score:.2f}"
if "MPC" in translation_methods:
mpc_candidate, mpc_score = mpc_translation(text, src_language, target_language,
max_iterations_value, session_id)
mpc_output = f"{mpc_candidate}\n\nScore: {mpc_score:.2f}"
else:
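        # Group the segmented sentences into chunks of `chunk_size` sentences
        # (e.g. 7 sentences with chunk_size=3 -> chunks of 3, 3, and 1),
        # translate each chunk independently, then rephrase the concatenation
        # into one fluent text via summary_translate below.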
chunks = [' '.join(source_segments[i:i+chunk_size]) for i in range(0, len(source_segments), chunk_size)]
org_translated_chunks = []
p2a_translated_chunks = []
bfn_translated_chunks = []
mpc_translated_chunks = []
for chunk in chunks:
if "Original" in translation_methods:
translation, _ = original_translation(chunk, src_language, target_language, session_id)
org_translated_chunks.append(translation)
if "Plan2Align" in translation_methods:
translation, _ = plan2align_translate_text(
chunk, session_id, model, tokenizer, device, src_language, target_language,
max_iterations_value, threshold_value, good_ref_contexts_num_value, "metricx"
)
p2a_translated_chunks.append(translation)
if "Best-of-N" in translation_methods:
translation, _ = best_of_n_translation(chunk, src_language, target_language, max_iterations_value, session_id)
bfn_translated_chunks.append(translation)
if "MPC" in translation_methods:
translation, _ = mpc_translation(chunk, src_language, target_language, max_iterations_value, session_id)
mpc_translated_chunks.append(translation)
        org_combined_translation = ' '.join(org_translated_chunks)
        p2a_combined_translation = ' '.join(p2a_translated_chunks)
        bfn_combined_translation = ' '.join(bfn_translated_chunks)
        mpc_combined_translation = ' '.join(mpc_translated_chunks)
        # Only the selected methods get an output, matching the chunk_size == -1 path.
        if "Original" in translation_methods:
            orig, best_score = summary_translate(text, org_combined_translation, target_language, session_id)
            orig_output = f"{orig}\n\nScore: {best_score:.2f}"
        if "Plan2Align" in translation_methods:
            plan2align_trans, best_score = summary_translate(text, p2a_combined_translation, target_language, session_id)
            plan2align_output = f"{plan2align_trans}\n\nScore: {best_score:.2f}"
        if "Best-of-N" in translation_methods:
            best_candidate, best_score = summary_translate(text, bfn_combined_translation, target_language, session_id)
            best_of_n_output = f"{best_candidate}\n\nScore: {best_score:.2f}"
        if "MPC" in translation_methods:
            mpc_candidate, best_score = summary_translate(text, mpc_combined_translation, target_language, session_id)
            mpc_output = f"{mpc_candidate}\n\nScore: {best_score:.2f}"
return orig_output, plan2align_output, best_of_n_output, mpc_output
# ---------- Gradio ----------
target_languages = ["Traditional Chinese", "Simplified Chinese", "English", "Russian", "German", "Japanese", "Korean"]
src_languages = ["Traditional Chinese", "Simplified Chinese", "English", "Russian", "German", "Japanese", "Korean"]
with gr.Blocks(title="Test-Time Machine Translation with Plan2Align") as demo:
state = gr.State({})
gr.Markdown("# Translation Demo: Multiple Translation Methods")
gr.Markdown("請選擇要執行的翻譯方法(可多選或全選):")
with gr.Row():
with gr.Column(scale=1):
source_text = gr.Textbox(
label="Source Text",
placeholder="請輸入文本...",
lines=5
)
src_language_input = gr.Dropdown(
choices=src_languages,
value="Traditional Chinese",
label="Source Language"
)
task_language_input = gr.Dropdown(
choices=target_languages,
value="English",
label="Target Language"
)
max_iterations_input = gr.Number(label="Max Iterations", value=6)
threshold_input = gr.Number(label="Threshold", value=0.7)
good_ref_contexts_num_input = gr.Number(label="Good Ref Contexts Num", value=5)
translation_methods_input = gr.CheckboxGroup(
choices=["Original", "Plan2Align", "Best-of-N", "MPC"],
value=["Original", "Plan2Align"],
label="Translation Methods"
)
            chunk_size_input = gr.Number(  # sentences per chunk; -1 translates the whole text at once
                label="Chunk Size (-1 for all)",
                value=-1
            )
translate_button = gr.Button("Translate")
with gr.Column(scale=2):
original_output = gr.Textbox(
label="Original Translation",
lines=5,
interactive=False
)
plan2align_output = gr.Textbox(
label="Plan2Align Translation",
lines=5,
interactive=False
)
best_of_n_output = gr.Textbox(
label="Best-of-N Translation",
lines=5,
interactive=False
)
mpc_output = gr.Textbox(
label="MPC Translation",
lines=5,
interactive=False
)
translate_button.click(
fn=process_text,
inputs=[
source_text,
src_language_input,
task_language_input,
max_iterations_input,
threshold_input,
good_ref_contexts_num_input,
translation_methods_input,
            chunk_size_input,
state
],
outputs=[original_output, plan2align_output, best_of_n_output, mpc_output]
)
gr.Examples(
examples=[
["台灣夜市文化豐富多彩,從士林夜市到饒河街夜市,提供各種美食、遊戲和購物體驗,吸引了無數遊客。", "Traditional Chinese", "English", 2, 0.7, 1, ["Original", "Plan2Align"], -1],
["台北101曾經是世界最高的建築物,它不僅是台灣的地標,也象徵著經濟成就和創新精神。", "Traditional Chinese", "Japanese", 2, 0.7, 1, ["Original", "Plan2Align"], -1],
["阿里山日出和森林鐵路是台灣最著名的自然景觀之一,每年吸引數十萬遊客前來欣賞雲海和壯麗的日出。", "Traditional Chinese", "Korean", 2, 0.7, 1, ["Original", "Plan2Align"], -1],
# ["珍珠奶茶,這款源自台灣的獨特飲品,不僅在台灣本地深受喜愛,更以其獨特的風味和口感,在全球掀起了一股熱潮,成為了一種跨越文化、風靡全球的時尚飲品。", "Traditional Chinese", "Japanese", 3, 0.7, 3, ["Original", "Plan2Align"], -1],
# ["原住民文化如同一片深邃的星空,閃爍著無數璀璨的傳統與藝術光芒。他們的歌舞,是與祖靈對話的旋律,是與自然共鳴的節奏,每一個舞步、每一聲吟唱,都承載著古老的傳說與智慧。編織,是他們巧手下的詩篇,一絲一線,交織出生命的紋理,也編織出對土地的熱愛與敬畏。木雕,則是他們與自然對話的雕塑,每一刀、每一鑿,都刻畫著對萬物的觀察與敬意,也雕琢出對祖先的追憶與傳承。", "Traditional Chinese", "Korean", 5, 0.7, 5, ["Original", "Plan2Align"], -1]
],
inputs=[
source_text,
src_language_input,
task_language_input,
max_iterations_input,
threshold_input,
good_ref_contexts_num_input,
translation_methods_input,
            chunk_size_input
],
outputs=[original_output, plan2align_output, best_of_n_output, mpc_output],
fn=process_text
)
gr.Markdown("## How It Works")
gr.Markdown("""
1. **Original Translation:** 利用固定提示生成候選,直接取首個候選作為原始翻譯。
2. **Plan2Align Translation:** 採用 context alignment 和 self-rewriting 策略進行翻譯,適合長文翻譯。
3. **Best-of-N Translation:** 重複生成多次候選,評分選出最佳翻譯,適合短文翻譯。
4. **MPC Translation:** 以迭代改善策略,每輪生成候選後評分,並將最佳翻譯作為下一輪輸入,適合短文翻譯。
若輸入文本超過 1024 tokens,Best-of-N 與 MPC 方法會回傳警告訊息。
""")
if __name__ == "__main__":
demo.launch(share=True)