# NOTE: The lines below were Hugging Face Spaces page metadata (status,
# file size, commit hashes, line-number gutter) captured together with the
# file during extraction; they are not part of the program.
# File size: 20,591 Bytes
import os
import gc
import gradio as gr
import torch
import random
import logging
from huggingface_hub import login, HfApi, snapshot_download
import spacy
import subprocess
import pkg_resources
import sys
# Authenticate with the Hugging Face Hub; the token lives in the Space
# secret "LA_NAME" (needed for gated models such as google/gemma-2-9b-it).
login(token=os.environ.get("LA_NAME"))
# NOTE(review): presumably consumed by the LASER/vecalign tooling imported
# below — confirm against vecalign.plan2align.
os.environ["LASER"] = "laser"
def check_and_install(package, required_version):
    """Ensure *package* is installed at exactly *required_version*.

    Compares the installed distribution version against the required one and
    (re)installs via pip when it differs or the package is missing.

    Args:
        package: Distribution name as known to pip (e.g. "torch").
        required_version: Exact version string to pin (e.g. "2.6.0").
    """
    # importlib.metadata is the stdlib replacement for the deprecated
    # pkg_resources API the original used; imported locally so this module
    # keeps working even if setuptools is absent.
    import importlib.metadata
    try:
        installed_version = importlib.metadata.version(package)
    except importlib.metadata.PackageNotFoundError:
        print(f"[{package}] not found, install: {required_version}...")
        subprocess.check_call(
            [sys.executable, "-m", "pip", "install", f"{package}=={required_version}"]
        )
        return
    if installed_version != required_version:
        # (Typo fix: space added after the comma in the message.)
        print(f"[{package}] already installed {installed_version}. Required version {required_version}, re-install...")
        subprocess.check_call(
            [sys.executable, "-m", "pip", "install",
             f"{package}=={required_version}", "--force-reinstall"]
        )
    else:
        print(f"[{package}] required version {required_version} finished")
# Exact versions required by this Space; any mismatch triggers a pip
# force-reinstall at startup via check_and_install above.
packages = {
    "pip": "24.0",
    "fairseq": "0.12.2",
    "torch": "2.6.0",
    "transformers": "4.51.3"
}
for package, version in packages.items():
    check_and_install(package, version)
from transformers import AutoTokenizer, AutoModelForCausalLM
from vecalign.plan2align import translate_text, external_find_best_translation
from trl import AutoModelForCausalLMWithValueHead
# spaCy pipelines used for sentence segmentation (see
# segment_sentences_by_punctuation). zh_core_web_sm serves both Simplified
# and Traditional Chinese; it was previously handled by a separate,
# duplicated try/except — folded into the single loop here.
models = ["en_core_web_sm", "ru_core_news_sm", "de_core_news_sm",
          "ja_core_news_sm", "ko_core_news_sm", "es_core_news_sm",
          "zh_core_web_sm"]
for model_name in models:
    try:
        spacy.load(model_name)
    except OSError:
        # Model not bundled in the image: fetch it once at startup.
        from spacy.cli import download
        download(model_name)
# NOTE(review): numpy is pinned last, presumably so that the installs above
# cannot bump it past what fairseq/vecalign tolerate — confirm.
subprocess.check_call([sys.executable, "-m", "pip", "install", "numpy==1.24.0", "--force-reinstall"])
# ---------- translation function ----------
# Initialize device: prefer the first CUDA GPU, otherwise CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# Load tokenizer and LLM once at module import; every translation method
# below shares these globals.
print("Loading models...")
model_id = "google/gemma-2-9b-it"  # alternative: "meta-llama/Meta-Llama-3.1-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)
# fp16 weights, sharded automatically across the available devices.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.float16
)
import spacy  # NOTE(review): duplicate of the top-of-file import; a harmless no-op.
# UI language name -> (ISO 639-1 code, spaCy pipeline name). Both Chinese
# variants share the same zh pipeline.
lang_map = {
    "English": ("en", "en_core_web_sm"),
    "Russian": ("ru", "ru_core_news_sm"),
    "German": ("de", "de_core_news_sm"),
    "Japanese": ("ja", "ja_core_news_sm"),
    "Korean": ("ko", "ko_core_news_sm"),
    "Spanish": ("es", "es_core_news_sm"),
    "Simplified Chinese": ("zh", "zh_core_web_sm"),
    "Traditional Chinese": ("zh", "zh_core_web_sm")
}
def get_lang_and_nlp(language):
    """Resolve a UI language name to ``(iso_code, loaded_spacy_pipeline)``.

    Raises:
        ValueError: if *language* is not listed in ``lang_map``.
    """
    try:
        code, pipeline_name = lang_map[language]
    except KeyError:
        raise ValueError(f"Unsupported language: {language}") from None
    return code, spacy.load(pipeline_name)
def segment_sentences_by_punctuation(text, src_nlp):
    """Split *text* into sentence strings using the given spaCy pipeline.

    Newline-separated paragraphs are segmented independently; blank
    paragraphs are skipped, and every sentence is whitespace-stripped.
    """
    sentences = []
    for paragraph in text.split('\n'):
        if not paragraph.strip():
            continue
        sentences.extend(sent.text.strip() for sent in src_nlp(paragraph).sents)
    return sentences
def generate_translation(system_prompt, prompt):
    """Run one chat-style generation and return only the newly generated text.

    The system and user messages are flattened into a single plain-text
    prompt; sampling (temperature/top_p) trades determinism for fluency.
    Uses the module-level ``tokenizer``, ``model`` and ``device``.
    """
    composed = f"System: {system_prompt}\nUser: {prompt}\nAssistant:"
    encoded = tokenizer(composed, return_tensors="pt").to(device)
    generated = model.generate(
        **encoded,
        max_new_tokens=2048,
        temperature=0.7,
        top_p=0.9,
        do_sample=True
    )
    # Drop the prompt tokens so only the assistant's continuation remains.
    prompt_len = encoded['input_ids'].shape[1]
    return tokenizer.decode(generated[0][prompt_len:], skip_special_tokens=True)
def check_token_length(text, max_tokens=1024):
    """Return True when *text* fits within *max_tokens* model tokens.

    The previous implementation compared the *character* count against a
    token budget, which mis-estimates capacity for most languages. This
    version counts real tokenizer tokens, and falls back to the original
    character-count proxy if the module-level tokenizer is unavailable.
    """
    try:
        return len(tokenizer.encode(text)) <= max_tokens
    except Exception:
        # Fallback: rough character-length proxy (original behaviour).
        return len(text) <= max_tokens
import uuid
def get_user_session(state=None):
    """Return a stable per-session id, creating and caching one in *state*.

    A missing or non-dict *state* is replaced by a fresh dict, so callers
    always receive a usable id even when Gradio passes unexpected state.
    """
    if not isinstance(state, dict):
        state = {}
    session_id = state.get("session_id")
    if not session_id:
        session_id = uuid.uuid4().hex
        state["session_id"] = session_id
    return session_id
# ---------- Translation Function ----------
def mpc_initial_translate(source_sentence, src_language, tgt_language):
    """Produce three candidate translations under different translator personas.

    Returns a list of three strings: literal, formal, and creative renderings
    of *source_sentence*, in that order.
    """
    personas = (
        "You are a meticulous translator. Provide a literal, word-for-word translation that preserves the structure and meaning of each individual word.",
        "You are a professional translator. Deliver a clear, formal, and precise translation that faithfully conveys the original meaning.",
        "You are a creative and expressive translator. Render the text in a vivid and imaginative way, as if narrating a captivating story."
    )
    request = (
        f"### Translate this from {src_language} to {tgt_language} and only output the result."
        f"\n### {src_language}:\n {source_sentence}"
        f"\n### {tgt_language}:\n"
    )
    candidates = [generate_translation(persona, request) for persona in personas]
    print("mpc_initial_translate")
    print(candidates)
    return candidates
def mpc_improved_translate(source_sentence, current_translation, src_language, tgt_language):
    """Refine an existing translation from three persona perspectives.

    Returns a list of three improved candidates based on *current_translation*.
    """
    personas = (
        "You are a meticulous translator. Please improve the following translation by ensuring it is a literal and structurally precise version.",
        "You are a professional translator. Please refine the provided translation to be clear, formal, and accurate.",
        "You are a creative translator. Please enhance the translation so that it is vivid, natural, and engaging."
    )
    request = (
        f"Source ({src_language}): {source_sentence}\n"
        f"Current Translation ({tgt_language}): {current_translation}\n"
        f"Please provide an improved translation into {tgt_language} and only output the result:"
    )
    candidates = [generate_translation(persona, request) for persona in personas]
    print("mpc_improved_translate")
    print(candidates)
    return candidates
def basic_translate(source_sentence, src_language, tgt_language):
    """Generate one plain translation candidate; returns a one-element list.

    The list return type keeps the signature parallel with the mpc_* helpers.
    """
    persona = "You are a helpful translator and only output the result."
    request = (
        f"### Translate this from {src_language} to {tgt_language}."
        f"\n### {src_language}:\n {source_sentence}"
        f"\n### {tgt_language}:\n"
    )
    return [generate_translation(persona, request)]
def summary_translate(src_text, temp_tgt_text, tgt_language, session_id):
    """Rephrase *temp_tgt_text* in *tgt_language* and score it against *src_text*.

    Returns ``(rephrased_text, score)``; ``("", 0)`` for blank input, and a
    score of 0 when candidate evaluation fails.
    """
    if not temp_tgt_text.strip():
        return "", 0
    persona = "You are a helpful rephraser. You only output the rephrased result."
    request = (
        f"### Rephrase the following in {tgt_language}."
        f"\n### Input:\n {temp_tgt_text}"
        f"\n### Rephrased:\n"
    )
    rephrased = generate_translation(persona, request)
    try:
        _, score = evaluate_candidates(src_text, [rephrased], tgt_language, session_id)
    except Exception:
        # Bug fix: was a bare `except:` (would swallow SystemExit /
        # KeyboardInterrupt). Scoring stays best-effort: keep the text,
        # report 0.
        score = 0
    return rephrased, score
def plan2align_translate_text(text, session_id, model, tokenizer, device, src_language, task_language, max_iterations_value, threshold_value, good_ref_contexts_num_value, reward_model_type):
    """Translate *text* with the Plan2Align pipeline and score the result.

    Thin wrapper over ``vecalign.plan2align.translate_text``; returns
    ``(translation, score)`` with a score of 0 when evaluation fails.
    """
    result = translate_text(
        text=text,
        model=model,
        tokenizer=tokenizer,
        device=device,
        src_language=src_language,
        task_language=task_language,
        max_iterations_value=max_iterations_value,
        threshold_value=threshold_value,
        good_ref_contexts_num_value=good_ref_contexts_num_value,
        reward_model_type=reward_model_type,
        session_id=session_id
    )
    try:
        _, score = evaluate_candidates(text, [result], task_language, session_id)
    except Exception:
        # Bug fix: was a bare `except:`. Scoring is best-effort; never fail
        # the translation itself.
        score = 0
    return result, score
def evaluate_candidates(source, candidates, language, session_id):
    """Score *candidates* against *source*; return the best ``(text, score)`` pair."""
    ranked = external_find_best_translation([(source, candidates)], language, session_id)
    return ranked[0]
def original_translation(text, src_language, target_language, session_id):
    """Single-shot baseline translation; returns ``(best_translation, score)``.

    Returns ``("", 0)`` when no candidate could be generated.
    """
    cand_list = basic_translate(text, src_language, target_language)
    if not cand_list:
        # Bug fix: the original evaluated candidates BEFORE checking for an
        # empty list, which could crash inside evaluate_candidates.
        return "", 0
    return evaluate_candidates(text, cand_list, target_language, session_id)
def best_of_n_translation(text, src_language, target_language, n, session_id):
    """Generate *n* independent candidates and return the highest-scoring one.

    Returns ``(translation, score)``; a warning string with score 0 when the
    input exceeds the 4096-token budget, and ``("", 0)`` when no candidate
    was produced.
    """
    if not check_token_length(text, 4096):
        return "Warning: Input text too long.", 0
    candidates = []
    for _ in range(n):
        cand_list = basic_translate(text, src_language, target_language)
        if cand_list:
            candidates.append(cand_list[0])
    if not candidates:
        return "", 0
    try:
        best, score = evaluate_candidates(text, candidates, target_language, session_id)
        print("best_of_n evaluate_candidates results:")
        print(best, score)
        return best, score
    except Exception:
        # Bug fixes: was a bare `except:` and returned the misleading
        # "Input text too long." message when *scoring* (not length) failed.
        # Fall back to the first candidate, consistent with mpc_translation.
        print("evaluate_candidates fail")
        return candidates[0], 0
def mpc_translation(text, src_language, target_language, iterations, session_id):
    """Iteratively translate and refine; returns ``(best_translation, score)``.

    Round 0 produces initial persona candidates; later rounds refine the
    current best. Each round keeps the top-scored candidate as input to the
    next round.
    """
    if not check_token_length(text, 4096):
        return "Warning: Input text too long.", 0
    current_trans = ""
    # Bug fix: was initialized to None, which made callers' "{:.2f}"
    # formatting crash whenever iterations <= 0.
    best_score = 0
    for i in range(iterations):
        if i == 0:
            cand_list = mpc_initial_translate(text, src_language, target_language)
        else:
            cand_list = mpc_improved_translate(text, current_trans, src_language, target_language)
        try:
            best, score = evaluate_candidates(text, cand_list, target_language, session_id)
            print("mpc evaluate_candidates results:")
            print(best, score)
            current_trans = best
            best_score = score
        except Exception:
            # Bug fix: was a bare `except:`. Scoring failed this round:
            # keep the first raw candidate and continue.
            print("evaluate_candidates fail")
            current_trans = cand_list[0]
            best_score = 0
    return current_trans, best_score
# ---------- Gradio function ----------
def process_text(text, src_language, target_language, max_iterations_value, threshold_value,
                 good_ref_contexts_num_value, translation_methods=None, chunk_size=-1, state=None):
    """Run the selected translation methods over *text* and format outputs.

    Returns a 4-tuple of display strings (original, plan2align, best_of_n,
    mpc); methods that were not selected yield empty strings. With
    ``chunk_size == -1`` each method sees the full text; otherwise the text
    is split into sentence chunks, translated per-chunk, then rephrased into
    a final combined translation.
    """
    translation_methods = translation_methods or ["Original", "Plan2Align"]
    session_id = get_user_session(state)
    # NOTE(review): the string below is a misplaced docstring (it follows two
    # statements, so it is a no-op expression); translated from Chinese.
    """
    Given the input text and target language, produce four translation
    results in order:
    1. Original translation
    2. Plan2Align translation
    3. Best-of-N translation
    4. MPC translation
    """
    orig_output = ""
    plan2align_output = ""
    best_of_n_output = ""
    mpc_output = ""
    # Sentence segmentation is only consumed by the chunked branch below,
    # but is computed unconditionally (preserving original behaviour).
    src_lang, src_nlp = get_lang_and_nlp(src_language)
    source_sentence = text.replace("\n", " ")
    source_segments = segment_sentences_by_punctuation(source_sentence, src_nlp)
    if chunk_size == -1:
        # Whole-text mode: each selected method translates the full input.
        if "Original" in translation_methods:
            orig, best_score = original_translation(text, src_language, target_language, session_id)
            orig_output = f"{orig}\n\nScore: {best_score:.2f}"
        if "Plan2Align" in translation_methods:
            plan2align_trans, best_score = plan2align_translate_text(
                text, session_id, model, tokenizer, device, src_language, target_language,
                max_iterations_value, threshold_value, good_ref_contexts_num_value, "metricx"
            )
            plan2align_output = f"{plan2align_trans}\n\nScore: {best_score:.2f}"
        if "Best-of-N" in translation_methods:
            # max_iterations_value doubles as N for Best-of-N.
            best_candidate, best_score = best_of_n_translation(text, src_language, target_language, max_iterations_value, session_id)
            best_of_n_output = f"{best_candidate}\n\nScore: {best_score:.2f}"
        if "MPC" in translation_methods:
            mpc_candidate, mpc_score = mpc_translation(text, src_language, target_language,
                                                       max_iterations_value, session_id)
            mpc_output = f"{mpc_candidate}\n\nScore: {mpc_score:.2f}"
    else:
        # Chunked mode: group sentences into fixed-size chunks and translate
        # each chunk independently with every selected method.
        chunks = [' '.join(source_segments[i:i+chunk_size]) for i in range(0, len(source_segments), chunk_size)]
        org_translated_chunks = []
        p2a_translated_chunks = []
        bfn_translated_chunks = []
        mpc_translated_chunks = []
        for chunk in chunks:
            if "Original" in translation_methods:
                translation, _ = original_translation(chunk, src_language, target_language, session_id)
                org_translated_chunks.append(translation)
            if "Plan2Align" in translation_methods:
                translation, _ = plan2align_translate_text(
                    chunk, session_id, model, tokenizer, device, src_language, target_language,
                    max_iterations_value, threshold_value, good_ref_contexts_num_value, "metricx"
                )
                p2a_translated_chunks.append(translation)
            if "Best-of-N" in translation_methods:
                translation, _ = best_of_n_translation(chunk, src_language, target_language, max_iterations_value, session_id)
                bfn_translated_chunks.append(translation)
            if "MPC" in translation_methods:
                translation, _ = mpc_translation(chunk, src_language, target_language, max_iterations_value, session_id)
                mpc_translated_chunks.append(translation)
        # Stitch per-chunk outputs together, then rephrase the joined text
        # for fluency and score it. summary_translate returns ("", 0) for
        # methods that produced no chunks, so unselected methods stay empty.
        org_combined_translation = ' '.join(org_translated_chunks)
        p2a_combined_translation = ' '.join(p2a_translated_chunks)
        bfn_combined_translation = ' '.join(bfn_translated_chunks)
        mpc_combined_translation = ' '.join(mpc_translated_chunks)
        orig, best_score = summary_translate(text, org_combined_translation, target_language, session_id)
        orig_output = f"{orig}\n\nScore: {best_score:.2f}"
        plan2align_trans, best_score = summary_translate(text, p2a_combined_translation, target_language, session_id)
        plan2align_output = f"{plan2align_trans}\n\nScore: {best_score:.2f}"
        best_candidate, best_score = summary_translate(text, bfn_combined_translation, target_language, session_id)
        best_of_n_output = f"{best_candidate}\n\nScore: {best_score:.2f}"
        mpc_candidate, best_score = summary_translate(text, mpc_combined_translation, target_language, session_id)
        mpc_output = f"{mpc_candidate}\n\nScore: {best_score:.2f}"
    return orig_output, plan2align_output, best_of_n_output, mpc_output
# ---------- Gradio ----------
# Languages exposed in the UI dropdowns; source and target offer the same set.
target_languages = ["Traditional Chinese", "Simplified Chinese", "English", "Russian", "German", "Japanese", "Korean"]
src_languages = ["Traditional Chinese", "Simplified Chinese", "English", "Russian", "German", "Japanese", "Korean"]
with gr.Blocks(title="Test-Time Machine Translation with Plan2Align") as demo:
    # Per-browser-session dict; get_user_session caches its session_id here.
    state = gr.State({})
    gr.Markdown("# Translation Demo: Multiple Translation Methods")
    gr.Markdown("請選擇要執行的翻譯方法(可多選或全選):")
    with gr.Row():
        with gr.Column(scale=1):
            # ---- input controls ----
            source_text = gr.Textbox(
                label="Source Text",
                placeholder="請輸入文本...",
                lines=5
            )
            src_language_input = gr.Dropdown(
                choices=src_languages,
                value="Traditional Chinese",
                label="Source Language"
            )
            task_language_input = gr.Dropdown(
                choices=target_languages,
                value="English",
                label="Target Language"
            )
            # "Max Iterations" also serves as N for Best-of-N and as the MPC
            # iteration count (see process_text).
            max_iterations_input = gr.Number(label="Max Iterations", value=6)
            threshold_input = gr.Number(label="Threshold", value=0.7)
            good_ref_contexts_num_input = gr.Number(label="Good Ref Contexts Num", value=5)
            translation_methods_input = gr.CheckboxGroup(
                choices=["Original", "Plan2Align", "Best-of-N", "MPC"],
                value=["Original", "Plan2Align"],
                label="Translation Methods"
            )
            chunk_size_input = gr.Number(  # sentences per chunk; -1 disables chunking
                label="Chunk Size (-1 for all)",
                value=-1
            )
            translate_button = gr.Button("Translate")
        with gr.Column(scale=2):
            # ---- one read-only output box per method ----
            original_output = gr.Textbox(
                label="Original Translation",
                lines=5,
                interactive=False
            )
            plan2align_output = gr.Textbox(
                label="Plan2Align Translation",
                lines=5,
                interactive=False
            )
            best_of_n_output = gr.Textbox(
                label="Best-of-N Translation",
                lines=5,
                interactive=False
            )
            mpc_output = gr.Textbox(
                label="MPC Translation",
                lines=5,
                interactive=False
            )
    # Wire the button to process_text; inputs/outputs are positional.
    translate_button.click(
        fn=process_text,
        inputs=[
            source_text,
            src_language_input,
            task_language_input,
            max_iterations_input,
            threshold_input,
            good_ref_contexts_num_input,
            translation_methods_input,
            chunk_size_input,  # chunking control
            state
        ],
        outputs=[original_output, plan2align_output, best_of_n_output, mpc_output]
    )
    # Clickable example rows; state is intentionally omitted from inputs here.
    gr.Examples(
        examples=[
            ["台灣夜市文化豐富多彩,從士林夜市到饒河街夜市,提供各種美食、遊戲和購物體驗,吸引了無數遊客。", "Traditional Chinese", "English", 2, 0.7, 1, ["Original", "Plan2Align"], -1],
            ["台北101曾經是世界最高的建築物,它不僅是台灣的地標,也象徵著經濟成就和創新精神。", "Traditional Chinese", "Japanese", 2, 0.7, 1, ["Original", "Plan2Align"], -1],
            ["阿里山日出和森林鐵路是台灣最著名的自然景觀之一,每年吸引數十萬遊客前來欣賞雲海和壯麗的日出。", "Traditional Chinese", "Korean", 2, 0.7, 1, ["Original", "Plan2Align"], -1],
            # ["珍珠奶茶,這款源自台灣的獨特飲品,不僅在台灣本地深受喜愛,更以其獨特的風味和口感,在全球掀起了一股熱潮,成為了一種跨越文化、風靡全球的時尚飲品。", "Traditional Chinese", "Japanese", 3, 0.7, 3, ["Original", "Plan2Align"], -1],
            # ["原住民文化如同一片深邃的星空,閃爍著無數璀璨的傳統與藝術光芒。他們的歌舞,是與祖靈對話的旋律,是與自然共鳴的節奏,每一個舞步、每一聲吟唱,都承載著古老的傳說與智慧。編織,是他們巧手下的詩篇,一絲一線,交織出生命的紋理,也編織出對土地的熱愛與敬畏。木雕,則是他們與自然對話的雕塑,每一刀、每一鑿,都刻畫著對萬物的觀察與敬意,也雕琢出對祖先的追憶與傳承。", "Traditional Chinese", "Korean", 5, 0.7, 5, ["Original", "Plan2Align"], -1]
        ],
        inputs=[
            source_text,
            src_language_input,
            task_language_input,
            max_iterations_input,
            threshold_input,
            good_ref_contexts_num_input,
            translation_methods_input,
            chunk_size_input  # chunking control
        ],
        outputs=[original_output, plan2align_output, best_of_n_output, mpc_output],
        fn=process_text
    )
    gr.Markdown("## How It Works")
    gr.Markdown("""
    1. **Original Translation:** 利用固定提示生成候選,直接取首個候選作為原始翻譯。
    2. **Plan2Align Translation:** 採用 context alignment 和 self-rewriting 策略進行翻譯,適合長文翻譯。
    3. **Best-of-N Translation:** 重複生成多次候選,評分選出最佳翻譯,適合短文翻譯。
    4. **MPC Translation:** 以迭代改善策略,每輪生成候選後評分,並將最佳翻譯作為下一輪輸入,適合短文翻譯。
    若輸入文本超過 1024 tokens,Best-of-N 與 MPC 方法會回傳警告訊息。
    """)
if __name__ == "__main__":
    # share=True additionally exposes a public Gradio link beyond the
    # local/Space server.
    demo.launch(share=True)