# Plan2Align-NV / app.py
# KuangDW
# update before pull
# a042d39
# raw
# history blame
# 15.5 kB
import os
import gc
import gradio as gr
import torch
import random
import logging
import openai
from openai import OpenAI
from vecalign.plan2align import translate_text, external_find_best_translation
from transformers import AutoTokenizer, AutoModelForCausalLM
from trl import AutoModelForCausalLMWithValueHead
from huggingface_hub import login, HfApi, snapshot_download
import spacy
import subprocess
import pkg_resources
import sys
# Log in to the Hugging Face Hub with the token read from the LA_NAME
# environment variable (os.environ.get yields None when it is unset).
login(token=os.environ.get("LA_NAME"))
# A second token, read from ENC, is used only for the LASER download below.
laser_token = os.environ.get("ENC")
# Download a snapshot of the KuangDW/laser repo; the call's return value is
# kept as laser_path.
# NOTE(review): newer huggingface_hub releases prefer `token=` over
# `use_auth_token=` — confirm against the installed version before changing.
laser_path = snapshot_download(repo_id="KuangDW/laser", use_auth_token=laser_token)
# Publish the snapshot location through the LASER environment variable —
# presumably consumed by the vecalign code imported above; confirm there.
os.environ["LASER"] = laser_path
def check_and_install(package, required_version):
    """Ensure ``package`` is installed at exactly ``required_version``.

    The installed version is looked up first.  If the package is missing it is
    installed with pip; if a different version is present, the required one is
    installed with ``--force-reinstall``.  pip is run through the current
    interpreter (``sys.executable -m pip``).

    Args:
        package: Distribution name to check and, if needed, install.
        required_version: Exact version string to compare against and pin.

    Raises:
        subprocess.CalledProcessError: If the pip command exits non-zero.
    """
    # importlib.metadata is the standard-library replacement for the
    # deprecated pkg_resources lookup; it is imported here so this function
    # does not depend on setuptools being importable.
    from importlib import metadata

    pip_install = [
        sys.executable, "-m", "pip", "install", f"{package}=={required_version}",
    ]

    try:
        installed_version = metadata.version(package)
    except metadata.PackageNotFoundError:
        print(f"[{package}] not found, install: {required_version}...")
        subprocess.check_call(pip_install)
        return

    if installed_version == required_version:
        print(f"[{package}] required version {required_version} finished")
        return

    print(f"[{package}] already installed {installed_version}. Required version {required_version},re-install...")
    subprocess.check_call(pip_install + ["--force-reinstall"])
# Exact versions that check_and_install must verify (and install when they
# differ or are missing) before the rest of the module runs.
packages = {"pip": "24.0", "fairseq": "0.12.2"}
for package, version in packages.items():
    check_and_install(package, version)
def _ensure_spacy_model(name):
    """Try to load the spaCy pipeline ``name``; download it on OSError."""
    try:
        spacy.load(name)
    except OSError:
        from spacy.cli import download

        download(name)


# Pipelines checked in a loop; the Chinese pipeline is checked right after.
models = [
    "en_core_web_sm",
    "ru_core_news_sm",
    "de_core_news_sm",
    "ja_core_news_sm",
    "ko_core_news_sm",
    "es_core_news_sm",
]
for model in models:
    _ensure_spacy_model(model)
_ensure_spacy_model("zh_core_web_sm")
# Unconditionally force-reinstall numpy 1.24.0 with pip on every start-up.
# NOTE(review): presumably this undoes a numpy version change caused by the
# installs above; it runs after torch has already been imported — confirm the
# running process actually picks up the reinstalled version.
subprocess.check_call([sys.executable, "-m", "pip", "install", "numpy==1.24.0", "--force-reinstall"])
# ---------- translation function ----------
# Initialize device: first CUDA GPU when available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Load the tokenizer and model once at import time; the translation helpers
# below use these module-level objects.
print("Loading models...")
model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Half precision is only used on GPU.  On the CPU fallback float16 kernels
# are missing or very slow in many PyTorch builds, so load in float32 there.
model_dtype = torch.float16 if device.type == "cuda" else torch.float32
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=model_dtype,
)
def generate_translation(system_prompt, prompt):
    """Generate one sampled reply from the module-level chat model.

    Args:
        system_prompt: Content of the system message.
        prompt: Content of the user message.

    Returns:
        The decoded text of the newly generated tokens only: the prompt
        tokens are sliced off and special tokens are skipped.
    """
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": prompt},
    ]
    # add_generation_prompt=True makes the chat template end with the
    # assistant-turn header, so the model answers the user message instead of
    # continuing it.
    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(device)
    outputs = model.generate(
        inputs,
        max_new_tokens=512,
        temperature=0.7,
        top_p=0.9,
        do_sample=True,
    )
    # The output sequence starts with the prompt; keep only what follows it.
    prompt_length = inputs.shape[1]
    return tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
def check_token_length(text, max_tokens=1024):
    """Return True when ``len(text)`` does not exceed ``max_tokens``.

    Despite the name, no tokenizer is involved: the comparison uses
    ``len(text)`` directly (the character count for a ``str``).

    Args:
        text: The input to measure.
        max_tokens: Upper bound on ``len(text)``; defaults to 1024.
    """
    length = len(text)
    return length <= max_tokens
import uuid
def get_user_session(state):
    """Return the session id stored in ``state``, creating one if absent.

    Args:
        state: Dict holding per-session data.  ``None`` (or another falsy
            non-dict value) is treated as "no state" and a throwaway dict is
            used instead.

    Returns:
        The value under ``state["session_id"]``; when it was missing or
        falsy, a new ``uuid.uuid4().hex`` string is stored there first.
    """
    # Keep an empty dict as-is: replacing it would write the id into a new
    # local object, so the caller's state would never remember the session.
    if not state and not isinstance(state, dict):
        state = {}
    if not state.get("session_id"):
        state["session_id"] = uuid.uuid4().hex
    return state["session_id"]
# ---------- Translation Function ----------
def mpc_initial_translate(source_sentence, src_language, tgt_language):
    """Produce one first-round translation per translator persona.

    The same user prompt is sent to generate_translation once for each of
    three system prompts (literal, formal, creative).

    Args:
        source_sentence: Text to translate.
        src_language: Name of the source language, inserted into the prompt.
        tgt_language: Name of the target language, inserted into the prompt.

    Returns:
        List of three generated translations, in system-prompt order.
    """
    system_prompts = [
        "You are a meticulous translator. Provide a literal, word-for-word translation that preserves the structure and meaning of each individual word.",
        "You are a professional translator. Deliver a clear, formal, and precise translation that faithfully conveys the original meaning.",
        "You are a creative and expressive translator. Render the text in a vivid and imaginative way, as if narrating a captivating story.",
    ]
    # The user prompt is the same for every persona, so build it once.
    user_prompt = (
        f"### Translate this from {src_language} to {tgt_language} and only output the result."
        f"\n### {src_language}:\n {source_sentence}"
        f"\n### {tgt_language}:\n"
    )
    translations = [
        generate_translation(persona, user_prompt) for persona in system_prompts
    ]
    print("mpc_initial_translate")
    print(translations)
    return translations
def mpc_improved_translate(source_sentence, current_translation, src_language, tgt_language):
    """Ask each translator persona to improve an existing translation.

    The same user prompt (source text plus current translation) is sent to
    generate_translation once for each of three system prompts.

    Args:
        source_sentence: Original text being translated.
        current_translation: Translation to be improved.
        src_language: Name of the source language, inserted into the prompt.
        tgt_language: Name of the target language, inserted into the prompt.

    Returns:
        List of three generated translations, in system-prompt order.
    """
    system_prompts = [
        "You are a meticulous translator. Please improve the following translation by ensuring it is a literal and structurally precise version.",
        "You are a professional translator. Please refine the provided translation to be clear, formal, and accurate.",
        "You are a creative translator. Please enhance the translation so that it is vivid, natural, and engaging.",
    ]
    # The user prompt is the same for every persona, so build it once.
    user_prompt = (
        f"Source ({src_language}): {source_sentence}\n"
        f"Current Translation ({tgt_language}): {current_translation}\n"
        f"Please provide an improved translation into {tgt_language} and only output the result:"
    )
    translations = [
        generate_translation(persona, user_prompt) for persona in system_prompts
    ]
    print("mpc_improved_translate")
    print(translations)
    return translations
def basic_translate(source_sentence, src_language, tgt_language):
    """Translate ``source_sentence`` with a single fixed system prompt.

    Args:
        source_sentence: Text to translate.
        src_language: Name of the source language, inserted into the prompt.
        tgt_language: Name of the target language, inserted into the prompt.

    Returns:
        List with one generated translation per system prompt (currently a
        single entry).
    """
    system_prompts = ["You are a helpful translator and only output the result."]
    user_prompt = (
        f"### Translate this from {src_language} to {tgt_language}."
        f"\n### {src_language}:\n {source_sentence}"
        f"\n### {tgt_language}:\n"
    )
    return [generate_translation(persona, user_prompt) for persona in system_prompts]
def plan2align_translate_text(text, session_id, model, tokenizer, device, src_language, task_language, max_iterations_value, threshold_value, good_ref_contexts_num_value, reward_model_type):
    """Translate ``text`` with translate_text and score the result.

    Every argument is forwarded to translate_text by keyword; the returned
    translation is then passed, as the only candidate, to
    evaluate_candidates to obtain its score.

    Returns:
        Tuple ``(result, score)``: the value returned by translate_text and
        the score reported by evaluate_candidates.
    """
    translate_kwargs = {
        "text": text,
        "model": model,
        "tokenizer": tokenizer,
        "device": device,
        "src_language": src_language,
        "task_language": task_language,
        "max_iterations_value": max_iterations_value,
        "threshold_value": threshold_value,
        "good_ref_contexts_num_value": good_ref_contexts_num_value,
        "reward_model_type": reward_model_type,
        "session_id": session_id,
    }
    result = translate_text(**translate_kwargs)
    # Only the score is needed here; the chosen candidate is discarded.
    _, score = evaluate_candidates(text, [result], task_language, session_id)
    return result, score
def evaluate_candidates(source, candidates, language, session_id):
    """Pick the best candidate translation for ``source``.

    The ``(source, candidates)`` pair is wrapped in a one-element list and
    handed to external_find_best_translation together with ``language`` and
    ``session_id``; the first entry of its result is unpacked.

    Returns:
        Tuple ``(best_candidate, best_score)`` taken from that first entry.
    """
    ranked = external_find_best_translation(
        [(source, candidates)], language, session_id
    )
    top_candidate, top_score = ranked[0]
    return top_candidate, top_score
def original_translation(text, src_language, target_language, session_id):
    """Produce the baseline translation and its score.

    Args:
        text: Text to translate.
        src_language: Name of the source language.
        target_language: Name of the target language.
        session_id: Session identifier forwarded to evaluate_candidates.

    Returns:
        Tuple ``(translation, score)``; ``("", 0)`` when basic_translate
        yields no candidates.
    """
    cand_list = basic_translate(text, src_language, target_language)
    # Guard before scoring so an empty candidate list is never evaluated.
    if not cand_list:
        return "", 0
    best, score = evaluate_candidates(text, cand_list, target_language, session_id)
    return best, score
def best_of_n_translation(text, src_language, target_language, n, session_id):
    """Generate ``n`` baseline translations and return the best-scoring one.

    Args:
        text: Text to translate.
        src_language: Name of the source language.
        target_language: Name of the target language.
        n: Number of candidates to generate; converted with ``int()`` because
            the UI may deliver it as a float.
        session_id: Session identifier forwarded to evaluate_candidates.

    Returns:
        Tuple ``(translation, score)``.  When check_token_length rejects the
        input, the translation is a warning message and the score is None.
        When no candidate was produced the result is ``("", 0)``.
    """
    if not check_token_length(text, 2048):
        # Same 2-tuple shape as the normal return so callers can unpack it.
        return "Warning: Input text exceeds 2048 tokens.", None

    candidates = []
    for _ in range(int(n)):
        cand_list = basic_translate(text, src_language, target_language)
        if cand_list:
            candidates.append(cand_list[0])

    # Nothing to score (e.g. n <= 0): use the same empty result as
    # original_translation.
    if not candidates:
        return "", 0

    best, score = evaluate_candidates(text, candidates, target_language, session_id)
    print("best_of_n evaluate_candidates results:")
    print(best, score)
    return best, score
def mpc_translation(text, src_language, target_language, iterations, session_id):
    """Iteratively refine a translation, keeping each round's best candidate.

    Round 0 generates candidates with mpc_initial_translate; later rounds
    call mpc_improved_translate on the previous round's best candidate.
    After each round evaluate_candidates picks the candidate carried forward.

    Args:
        text: Text to translate.
        src_language: Name of the source language.
        target_language: Name of the target language.
        iterations: Number of rounds; converted with ``int()`` because the UI
            may deliver it as a float.
        session_id: Session identifier forwarded to evaluate_candidates.

    Returns:
        Tuple ``(translation, score)`` from the last round.  When
        check_token_length rejects the input, the translation is a warning
        message and the score is None.  With zero rounds the result is
        ``("", None)``.
    """
    if not check_token_length(text, 2048):
        # Same 2-tuple shape as the normal return so callers can unpack it.
        return "Warning: Input text exceeds 2048 tokens.", None

    current_trans = ""
    best_score = None
    for i in range(int(iterations)):
        if i == 0:
            cand_list = mpc_initial_translate(text, src_language, target_language)
        else:
            cand_list = mpc_improved_translate(text, current_trans, src_language, target_language)
        best, score = evaluate_candidates(text, cand_list, target_language, session_id)
        print("mpc evaluate_candidates results:")
        print(best, score)
        # The winner of this round is the input of the next one.
        current_trans = best
        best_score = score
    return current_trans, best_score
# ---------- Gradio function ----------
def _format_translation_output(translation, score):
    """Render a translation for display, appending the score when there is one."""
    if score is None:
        return f"{translation}"
    return f"{translation}\n\nScore: {score:.2f}"


def process_text(text, src_language, target_language, max_iterations_value, threshold_value,
                 good_ref_contexts_num_value, translation_methods=None, state=None):
    """Run the selected translation methods on ``text``.

    Given the source text and the target language, up to four results are
    produced, in this order:
      1. Original translation
      2. Plan2Align translation
      3. Best-of-N translation
      4. MPC translation

    Args:
        text: Source text to translate.
        src_language: Name of the source language.
        target_language: Name of the target language.
        max_iterations_value: Forwarded to Plan2Align; also used as the
            candidate count for Best-of-N and the round count for MPC.
        threshold_value: Forwarded to Plan2Align.
        good_ref_contexts_num_value: Forwarded to Plan2Align.
        translation_methods: Method names to run; defaults to
            ``["Original", "Plan2Align"]`` when empty or None.
        state: Per-session dict used to look up or create the session id.

    Returns:
        Tuple of four display strings (original, Plan2Align, Best-of-N, MPC).
        Methods that were not selected yield an empty string.
    """
    translation_methods = translation_methods or ["Original", "Plan2Align"]
    # Only substitute a missing state; an empty dict is passed on unchanged so
    # the session id can be stored in the caller's object.
    if state is None:
        state = {}
    session_id = get_user_session(state)

    orig_output = ""
    plan2align_output = ""
    best_of_n_output = ""
    mpc_output = ""

    if "Original" in translation_methods:
        orig, best_score = original_translation(text, src_language, target_language, session_id)
        orig_output = _format_translation_output(orig, best_score)

    if "Plan2Align" in translation_methods:
        plan2align_trans, best_score = plan2align_translate_text(
            text, session_id, model, tokenizer, device, src_language, target_language,
            max_iterations_value, threshold_value, good_ref_contexts_num_value, "metricx"
        )
        plan2align_output = _format_translation_output(plan2align_trans, best_score)

    if "Best-of-N" in translation_methods:
        # Index instead of unpacking so a warning result with extra trailing
        # elements is tolerated as well.
        result = best_of_n_translation(text, src_language, target_language,
                                       max_iterations_value, session_id)
        best_of_n_output = _format_translation_output(result[0], result[1])

    if "MPC" in translation_methods:
        result = mpc_translation(text, src_language, target_language,
                                 max_iterations_value, session_id)
        mpc_output = _format_translation_output(result[0], result[1])

    return orig_output, plan2align_output, best_of_n_output, mpc_output
# ---------- Gradio ----------
# Language names offered in the two dropdowns.
target_languages = ["Chinese", "English", "Russian", "German", "Japanese", "Korean"]
src_languages = ["Chinese", "English", "Russian", "German", "Japanese", "Korean"]

with gr.Blocks(title="Test-Time Machine Translation with Plan2Align") as demo:
    # Per-session dict handed to process_text, which reads/creates the
    # session id in it.
    state = gr.State({})
    gr.Markdown("# Translation Demo: Multiple Translation Methods")
    gr.Markdown("請選擇要執行的翻譯方法(可多選或全選):")

    with gr.Row():
        # Left column: inputs and the trigger button.
        with gr.Column(scale=1):
            source_text = gr.Textbox(
                label="Source Text",
                placeholder="請輸入文本...",
                lines=5
            )
            src_language_input = gr.Dropdown(
                choices=src_languages,
                value="Chinese",
                label="Source Language"
            )
            task_language_input = gr.Dropdown(
                choices=target_languages,
                value="English",
                label="Task (Target) Language"
            )
            max_iterations_input = gr.Number(label="Max Iterations", value=6)
            threshold_input = gr.Number(label="Threshold", value=0.7)
            good_ref_contexts_num_input = gr.Number(label="Good Ref Contexts Num", value=5)
            translation_methods_input = gr.CheckboxGroup(
                choices=["Original", "Plan2Align", "Best-of-N", "MPC"],
                value=["Original", "Plan2Align"],
                label="Translation Methods"
            )
            translate_button = gr.Button("Translate")

        # Right column: one read-only box per translation method.
        with gr.Column(scale=2):
            original_output = gr.Textbox(
                label="Original Translation",
                lines=5,
                interactive=False
            )
            plan2align_output = gr.Textbox(
                label="Plan2Align Translation",
                lines=5,
                interactive=False
            )
            best_of_n_output = gr.Textbox(
                label="Best-of-N Translation",
                lines=5,
                interactive=False
            )
            mpc_output = gr.Textbox(
                label="MPC Translation",
                lines=5,
                interactive=False
            )

    # The input order must match process_text's positional parameters.
    translate_button.click(
        fn=process_text,
        inputs=[
            source_text,
            src_language_input,
            task_language_input,
            max_iterations_input,
            threshold_input,
            good_ref_contexts_num_input,
            translation_methods_input,
            state
        ],
        outputs=[original_output, plan2align_output, best_of_n_output, mpc_output]
    )

    # Examples supply only the first six inputs; process_text falls back to
    # its defaults for translation_methods and state.
    gr.Examples(
        examples=[
            ["台灣夜市文化豐富多彩,從士林夜市到饒河街夜市,提供各種美食、遊戲和購物體驗,吸引了無數遊客。", "Chinese", "English", 2, 0.7, 1],
            ["台北101曾經是世界最高的建築物,它不僅是台灣的地標,也象徵著經濟成就和創新精神。", "Chinese", "Russian", 2, 0.7, 1],
            ["阿里山日出和森林鐵路是台灣最著名的自然景觀之一,每年吸引數十萬遊客前來欣賞雲海和壯麗的日出。", "Chinese", "German", 2, 0.7, 1],
            ["珍珠奶茶,這款源自台灣的獨特飲品,不僅在台灣本地深受喜愛,更以其獨特的風味和口感,在全球掀起了一股熱潮,成為了一種跨越文化、風靡全球的時尚飲品。", "Chinese", "Japanese", 3, 0.7, 3],
            ["原住民文化如同一片深邃的星空,閃爍著無數璀璨的傳統與藝術光芒。他們的歌舞,是與祖靈對話的旋律,是與自然共鳴的節奏,每一個舞步、每一聲吟唱,都承載著古老的傳說與智慧。編織,是他們巧手下的詩篇,一絲一線,交織出生命的紋理,也編織出對土地的熱愛與敬畏。木雕,則是他們與自然對話的雕塑,每一刀、每一鑿,都刻畫著對萬物的觀察與敬意,也雕琢出對祖先的追憶與傳承。", "Chinese", "Korean", 5, 0.7, 5]
        ],
        inputs=[
            source_text,
            src_language_input,
            task_language_input,
            max_iterations_input,
            threshold_input,
            good_ref_contexts_num_input
        ],
        outputs=[original_output, plan2align_output, best_of_n_output, mpc_output],
        fn=process_text
    )

    gr.Markdown("## How It Works")
    # The limit quoted in the last line matches the 2048 checked by the
    # Best-of-N and MPC functions (the text previously said 1024).
    gr.Markdown("""
1. **Original Translation:** 利用固定提示生成候選,直接取首個候選作為原始翻譯。
2. **Plan2Align Translation:** 採用 context alignment 和 self-rewriting 策略進行翻譯,適合長文翻譯。
3. **Best-of-N Translation:** 重複生成多次候選,評分選出最佳翻譯,適合短文翻譯。
4. **MPC Translation:** 以迭代改善策略,每輪生成候選後評分,並將最佳翻譯作為下一輪輸入,適合短文翻譯。
若輸入文本超過 2048 tokens,Best-of-N 與 MPC 方法會回傳警告訊息。
""")
# Launch the Gradio app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()