File size: 20,591 Bytes
dd05f29
 
7f92284
dd05f29
 
 
946f7f8
dd05f29
 
 
 
7f92284
773aaec
05d3571
7f92284
dd05f29
 
 
 
 
 
 
 
 
 
 
 
 
 
b76263f
d3c7e72
10beb7e
dd05f29
3c0f52c
dd05f29
 
 
2e5836c
3eb1ee6
 
dd05f29
 
 
 
 
 
 
 
 
 
 
 
 
 
0e20755
dd05f29
946f7f8
 
 
 
 
 
 
57a7224
946f7f8
 
 
 
 
dd05f29
 
18ff227
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57a7224
dd05f29
57a7224
 
946f7f8
57a7224
40c048f
57a7224
 
 
 
 
dd05f29
 
 
 
 
 
ff6a854
 
 
 
6d30719
dd05f29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1abf296
5a7ce56
 
 
18ff227
 
 
 
dbfc6d2
18ff227
 
 
 
a6d920a
 
 
 
 
 
18ff227
03c399b
dd05f29
03c399b
 
 
 
dd05f29
 
 
 
 
 
 
 
5e39340
 
 
 
dd05f29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5e39340
 
dd05f29
 
 
 
 
5e39340
 
 
 
 
 
 
dd05f29
 
 
5e39340
 
dd05f29
 
 
 
 
 
 
 
5e39340
 
 
 
 
 
 
 
 
 
 
dd05f29
 
946f7f8
dd05f29
 
1d1410b
6d30719
0fbaf56
dd05f29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18ff227
 
 
 
 
 
1d63826
 
18ff227
1d63826
 
18ff227
 
1d63826
18ff227
1d63826
 
18ff227
1d63826
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dd05f29
 
 
 
87d5a16
 
dd05f29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87d5a16
dd05f29
 
 
 
 
0e34976
dd05f29
 
 
 
 
 
 
 
 
18ff227
0e34976
18ff227
 
dd05f29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18ff227
dd05f29
 
 
 
 
 
 
18ff227
097ec64
 
 
 
dd05f29
 
 
 
 
 
 
ff6a854
18ff227
 
dd05f29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ff6a854
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
import os
import gc
import gradio as gr
import torch
import random
import logging
from huggingface_hub import login, HfApi, snapshot_download
import spacy
import subprocess
import pkg_resources
import sys

# Authenticate against the Hugging Face Hub.
# NOTE(review): the token is read from env var "LA_NAME" — unusual name for a
# token variable; confirm it actually holds an HF access token.
login(token=os.environ.get("LA_NAME"))
# Point LASER tooling at its resource directory (relative path "laser").
os.environ["LASER"] = "laser"

def check_and_install(package, required_version):
    """Ensure *package* is installed at exactly *required_version*.

    Installs the package if missing, or force-reinstalls it when the
    installed version differs from the pinned one. No-op (besides a log
    line) when the versions already match.
    """
    try:
        installed_version = pkg_resources.get_distribution(package).version
    except pkg_resources.DistributionNotFound:
        # Not installed at all: plain install of the pinned version.
        print(f"[{package}] not found, install: {required_version}...")
        subprocess.check_call([sys.executable, "-m", "pip", "install", f"{package}=={required_version}"])
        return
    if installed_version == required_version:
        print(f"[{package}] required version  {required_version} finished")
    else:
        # Version mismatch: force-reinstall to the pinned version.
        print(f"[{package}] already installed {installed_version}. Required version  {required_version},re-install...")
        subprocess.check_call([sys.executable, "-m", "pip", "install", f"{package}=={required_version}", "--force-reinstall"])
# Runtime dependencies this app pins; anything off-version is re-installed.
packages = {
    "pip": "24.0",
    "fairseq": "0.12.2",
    "torch": "2.6.0",
    "transformers": "4.51.3",
}

for pkg_name, pkg_version in packages.items():
    check_and_install(pkg_name, pkg_version)

from transformers import AutoTokenizer, AutoModelForCausalLM
from vecalign.plan2align import translate_text, external_find_best_translation
from trl import AutoModelForCausalLMWithValueHead

# Ensure the spaCy sentence-segmentation pipelines are available locally,
# downloading any that are missing.
models = ["en_core_web_sm", "ru_core_news_sm", "de_core_news_sm", 
          "ja_core_news_sm", "ko_core_news_sm", "es_core_news_sm"]
for model in models:
    try:
        spacy.load(model)
    except OSError:  # spaCy raises OSError when a model package is not installed
        from spacy.cli import download
        download(model)
# The Chinese pipeline is handled separately: after a fresh download, numpy is
# pinned to 1.24.0 — presumably for compatibility with the zh model's binary
# deps; TODO confirm why this applies only on the download path.
try:
    spacy.load("zh_core_web_sm")
except OSError:
    from spacy.cli import download
    download("zh_core_web_sm")
    subprocess.check_call([sys.executable, "-m", "pip", "install", "numpy==1.24.0", "--force-reinstall"])

# ---------- translation function ----------

# Initialize device: first CUDA GPU if available, otherwise CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# Load models once
# This single chat LLM backs every translation method below. fp16 weights;
# device_map="auto" lets accelerate place layers across available devices.
print("Loading models...")
model_id = "google/gemma-2-9b-it" # "meta-llama/Meta-Llama-3.1-8B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id, 
    device_map="auto",
    torch_dtype=torch.float16
)

import spacy  # NOTE(review): duplicate of the top-of-file import; harmless but redundant
# Map UI language labels to (ISO 639-1 code, spaCy pipeline name).
# Both Chinese variants deliberately share the same zh pipeline.
lang_map = {
    "English": ("en", "en_core_web_sm"),
    "Russian": ("ru", "ru_core_news_sm"),
    "German": ("de", "de_core_news_sm"),
    "Japanese": ("ja", "ja_core_news_sm"),
    "Korean": ("ko", "ko_core_news_sm"),
    "Spanish": ("es", "es_core_news_sm"),
    "Simplified Chinese": ("zh", "zh_core_web_sm"),
    "Traditional Chinese": ("zh", "zh_core_web_sm")
}

def get_lang_and_nlp(language):
    """Resolve a UI language label into (iso_code, loaded spaCy pipeline).

    Raises ValueError for labels missing from ``lang_map``.
    """
    if language in lang_map:
        lang_code, model_name = lang_map[language]
        return lang_code, spacy.load(model_name)
    raise ValueError(f"Unsupported language: {language}")

def segment_sentences_by_punctuation(text, src_nlp):
    """Split *text* into stripped sentences using the given spaCy pipeline.

    The text is first broken into paragraphs on newlines; blank paragraphs
    are skipped, and each remaining paragraph is sentence-segmented.
    """
    sentences = []
    for paragraph in text.split('\n'):
        if not paragraph.strip():
            continue
        sentences.extend(sent.text.strip() for sent in src_nlp(paragraph).sents)
    return sentences

def generate_translation(system_prompt, prompt):
    """Run one sampled generation with the module-level LLM.

    Builds a System/User/Assistant chat prompt, samples up to 2048 new
    tokens, and returns only the newly generated text (prompt stripped).
    """
    chat_text = f"System: {system_prompt}\nUser: {prompt}\nAssistant:"
    encoded = tokenizer(chat_text, return_tensors="pt").to(device)
    generated = model.generate(
        **encoded,
        max_new_tokens=2048,
        temperature=0.7,
        top_p=0.9,
        do_sample=True
    )
    # Slice off the prompt tokens so only the completion is decoded.
    new_tokens = generated[0][encoded['input_ids'].shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)

def check_token_length(text, max_tokens=1024):
    """Return True when *text* fits within the budget.

    NOTE(review): despite the name, this counts *characters*, not tokenizer
    tokens — it is a cheap upper-bound proxy; confirm that is intended.
    """
    return not len(text) > max_tokens

import uuid

def get_user_session(state=None):
    """Return a stable per-session id stored in the Gradio *state* dict.

    A fresh 32-char hex id is minted on first use and cached under
    ``state["session_id"]`` so later calls with the same dict reuse it.
    Non-dict (including None) state gets a throwaway dict, i.e. a new id.

    Fix: the original had a redundant ``state is None`` branch — None is
    already caught by the isinstance check below.
    """
    if not isinstance(state, dict):
        state = {}
    # Treat a missing OR falsy ("" / None) stored id as absent.
    if not state.get("session_id"):
        state["session_id"] = uuid.uuid4().hex
    return state["session_id"]

# ---------- Translation Function ----------

def mpc_initial_translate(source_sentence, src_language, tgt_language):
    """Generate three first-pass candidates under three translator personas.

    The user prompt is identical for every persona; only the system prompt
    (literal / formal / creative) varies. Returns the list of candidates.
    """
    system_prompts = [
        "You are a meticulous translator. Provide a literal, word-for-word translation that preserves the structure and meaning of each individual word.",
        "You are a professional translator. Deliver a clear, formal, and precise translation that faithfully conveys the original meaning.",
        "You are a creative and expressive translator. Render the text in a vivid and imaginative way, as if narrating a captivating story."
    ]
    # The prompt does not depend on the persona, so build it once.
    base_prompt = (
        f"### Translate this from {src_language} to {tgt_language} and only output the result."
        f"\n### {src_language}:\n {source_sentence}"
        f"\n### {tgt_language}:\n"
    )
    translations = [generate_translation(persona, base_prompt) for persona in system_prompts]
    print("mpc_initial_translate")
    print(translations)
    return translations

def mpc_improved_translate(source_sentence, current_translation, src_language, tgt_language):
    """Refine *current_translation* under three translator personas.

    Same persona scheme as ``mpc_initial_translate`` but the prompt asks
    the model to improve an existing translation. Returns the candidates.
    """
    system_prompts = [
        "You are a meticulous translator. Please improve the following translation by ensuring it is a literal and structurally precise version.",
        "You are a professional translator. Please refine the provided translation to be clear, formal, and accurate.",
        "You are a creative translator. Please enhance the translation so that it is vivid, natural, and engaging."
    ]
    # Prompt is persona-independent; build once and reuse.
    refine_prompt = (f"Source ({src_language}): {source_sentence}\n"
                     f"Current Translation ({tgt_language}): {current_translation}\n"
                     f"Please provide an improved translation into {tgt_language} and only output the result:")
    translations = [generate_translation(persona, refine_prompt) for persona in system_prompts]
    print("mpc_improved_translate")
    print(translations)
    return translations

def basic_translate(source_sentence, src_language, tgt_language):
    """Produce a single baseline candidate with a neutral system prompt.

    Returns a one-element list to keep the interface shape consistent with
    the multi-candidate translators.
    """
    system_prompts = ["You are a helpful translator and only output the result."]
    base_prompt = (
        f"### Translate this from {src_language} to {tgt_language}."
        f"\n### {src_language}:\n {source_sentence}"
        f"\n### {tgt_language}:\n"
    )
    return [generate_translation(persona, base_prompt) for persona in system_prompts]

def summary_translate(src_text, temp_tgt_text, tgt_language, session_id):
    """Rephrase concatenated chunk translations into fluent target text.

    Args:
        src_text: original source text (used only for scoring).
        temp_tgt_text: rough target-language text to be rephrased.
        tgt_language: target language label.
        session_id: per-user session id for the evaluator.

    Returns:
        (rephrased_text, score); ("", 0) when the input is empty.
    """
    if not temp_tgt_text.strip():
        return "", 0

    system_prompts = ["You are a helpful rephraser. You only output the rephrased result."]
    translations = []
    for prompt_style in system_prompts:
        prompt = f"### Rephrase the following in {tgt_language}."
        prompt += f"\n### Input:\n {temp_tgt_text}"
        prompt += f"\n### Rephrased:\n"
        translation = generate_translation(prompt_style, prompt)
        translations.append(translation)

    # Scoring is best-effort; fix: narrow the original bare `except:` so that
    # KeyboardInterrupt/SystemExit are no longer swallowed.
    try:
        _, score = evaluate_candidates(src_text, translations, tgt_language, session_id)
    except Exception:
        score = 0

    return translations[0], score

def plan2align_translate_text(text, session_id, model, tokenizer, device, src_language, task_language, max_iterations_value, threshold_value, good_ref_contexts_num_value, reward_model_type):
    """Translate *text* with the Plan2Align pipeline, then score the result.

    Thin wrapper over ``vecalign.plan2align.translate_text`` that forwards
    all tuning knobs unchanged and appends a best-effort quality score.

    Returns:
        (translation, score); score falls back to 0 if evaluation fails.
    """
    result = translate_text(
        text = text,
        model = model,
        tokenizer = tokenizer,
        device = device,
        src_language=src_language,
        task_language=task_language,
        max_iterations_value=max_iterations_value,
        threshold_value=threshold_value,
        good_ref_contexts_num_value=good_ref_contexts_num_value,
        reward_model_type=reward_model_type,
        session_id=session_id
    )
    # Fix: narrow the original bare `except:` — a scoring failure should not
    # swallow KeyboardInterrupt/SystemExit, and must not lose the translation.
    try:
        _, score = evaluate_candidates(text, [result], task_language, session_id)
    except Exception:
        score = 0
    return result, score

def evaluate_candidates(source, candidates, language, session_id):
    """Rank *candidates* against *source*; return (best_candidate, best_score).

    Delegates to the project-level ``external_find_best_translation``, which
    takes a list of (source, candidates) pairs and returns one (candidate,
    score) pair per input.
    """
    ranked = external_find_best_translation([(source, candidates)], language, session_id)
    return ranked[0]

def original_translation(text, src_language, target_language, session_id):
    """Baseline method: one candidate from ``basic_translate``, then scored.

    Fix: the original called ``evaluate_candidates`` *before* checking
    whether any candidates existed, making the emptiness check dead code
    (and evaluating an empty list). Now an empty candidate list
    short-circuits to ("", 0).

    Returns:
        (best_translation, score), or ("", 0) when no candidate was produced.
    """
    cand_list = basic_translate(text, src_language, target_language)
    if not cand_list:
        return "", 0
    return evaluate_candidates(text, cand_list, target_language, session_id)

def best_of_n_translation(text, src_language, target_language, n, session_id):
    """Best-of-N: sample *n* independent candidates and keep the top-scored.

    Returns:
        (best_translation, score). Inputs longer than 4096 characters are
        rejected with a warning string (see ``check_token_length``).
    """
    if not check_token_length(text, 4096):
        return "Warning: Input text too long.", 0
    candidates = []
    for _ in range(n):
        cand_list = basic_translate(text, src_language, target_language)
        if cand_list:
            candidates.append(cand_list[0])
    # Fix: narrow the original bare `except:` to Exception.
    try:
        best, score = evaluate_candidates(text, candidates, target_language, session_id)
        print("best_of_n evaluate_candidates results:")
        print(best, score)
    except Exception:
        # NOTE(review): this fallback message claims the input was too long,
        # but this branch actually means scoring failed; message kept for
        # output compatibility with existing callers/UI.
        print("evaluate_candidates fail")
        return "Warning: Input text too long.", 0
    return best, score

def mpc_translation(text, src_language, target_language, iterations, session_id):
    """MPC-style translation: iteratively refine the best candidate.

    Round 0 translates from scratch (three personas); each later round
    refines the current best. The scored winner of each round feeds the
    next one.

    Returns:
        (translation, score). NOTE(review): with ``iterations <= 0`` this
        returns ("", None) and callers format the score with ``{:.2f}``,
        which would raise — the UI default (6) avoids this; confirm whether
        a guard is wanted.
    """
    if not check_token_length(text, 4096):
        return "Warning: Input text too long.", 0
    current_trans = ""
    best_score = None
    for i in range(iterations):
        if i == 0:
            cand_list = mpc_initial_translate(text, src_language, target_language)
        else:
            cand_list = mpc_improved_translate(text, current_trans, src_language, target_language)

        # Fix: narrow the original bare `except:` to Exception; on scoring
        # failure fall back to the first candidate instead of crashing.
        try:
            best, score = evaluate_candidates(text, cand_list, target_language, session_id)
            print("mpc evaluate_candidates results:")
            print(best, score)
            current_trans = best
            best_score = score
        except Exception:
            print("evaluate_candidates fail")
            current_trans = cand_list[0]
            best_score = 0

    return current_trans, best_score

# ---------- Gradio function ----------

def process_text(text, src_language, target_language, max_iterations_value, threshold_value,
                 good_ref_contexts_num_value, translation_methods=None, chunk_size=-1, state=None):
    """Run the selected translation methods and return four display strings.

    Produces, in order: (1) Original translation, (2) Plan2Align translation,
    (3) Best-of-N translation, (4) MPC translation. Methods not selected in
    *translation_methods* yield an empty string.

    chunk_size == -1 translates the whole text at once; otherwise the text is
    sentence-segmented, grouped into chunks of *chunk_size* sentences, each
    chunk is translated, and the concatenation is rephrased and scored via
    ``summary_translate``.

    Fix: the original carried this description as a bare (non-docstring)
    Chinese string in the middle of the body; it is now a proper docstring.
    """
    translation_methods = translation_methods or ["Original", "Plan2Align"]
    session_id = get_user_session(state)

    orig_output = ""
    plan2align_output = ""
    best_of_n_output = ""
    mpc_output = ""

    # Sentence segmentation of the (newline-flattened) source; only the
    # chunked path below actually consumes source_segments.
    src_lang, src_nlp = get_lang_and_nlp(src_language)
    source_sentence = text.replace("\n", " ")
    source_segments = segment_sentences_by_punctuation(source_sentence, src_nlp)

    if chunk_size == -1:
        # Whole-text path: each selected method runs on the full input.
        if "Original" in translation_methods:
            orig, best_score = original_translation(text, src_language, target_language, session_id)
            orig_output = f"{orig}\n\nScore: {best_score:.2f}"
        if "Plan2Align" in translation_methods:
            plan2align_trans, best_score = plan2align_translate_text(
                text, session_id, model, tokenizer, device, src_language, target_language,
                max_iterations_value, threshold_value, good_ref_contexts_num_value, "metricx"
            )
            plan2align_output = f"{plan2align_trans}\n\nScore: {best_score:.2f}"
        if "Best-of-N" in translation_methods:
            # max_iterations_value doubles as N here.
            best_candidate, best_score = best_of_n_translation(text, src_language, target_language, max_iterations_value, session_id)
            best_of_n_output = f"{best_candidate}\n\nScore: {best_score:.2f}"
        if "MPC" in translation_methods:
            mpc_candidate, mpc_score = mpc_translation(text, src_language, target_language,
                                                       max_iterations_value, session_id)
            mpc_output = f"{mpc_candidate}\n\nScore: {mpc_score:.2f}"
    else:
        # Chunked path: translate chunk-by-chunk, then rephrase the joined result.
        chunks = [' '.join(source_segments[i:i+chunk_size]) for i in range(0, len(source_segments), chunk_size)]

        org_translated_chunks = []
        p2a_translated_chunks = []
        bfn_translated_chunks = []
        mpc_translated_chunks = []

        for chunk in chunks:
            if "Original" in translation_methods:
                translation, _ = original_translation(chunk, src_language, target_language, session_id)
                org_translated_chunks.append(translation)
            if "Plan2Align" in translation_methods:
                translation, _ = plan2align_translate_text(
                    chunk, session_id, model, tokenizer, device, src_language, target_language,
                    max_iterations_value, threshold_value, good_ref_contexts_num_value, "metricx"
                )
                p2a_translated_chunks.append(translation)
            if "Best-of-N" in translation_methods:
                translation, _ = best_of_n_translation(chunk, src_language, target_language, max_iterations_value, session_id)
                bfn_translated_chunks.append(translation)
            if "MPC" in translation_methods:
                translation, _ = mpc_translation(chunk, src_language, target_language, max_iterations_value, session_id)
                mpc_translated_chunks.append(translation)

        org_combined_translation = ' '.join(org_translated_chunks)
        p2a_combined_translation = ' '.join(p2a_translated_chunks)
        bfn_combined_translation = ' '.join(bfn_translated_chunks)
        mpc_combined_translation = ' '.join(mpc_translated_chunks)

        # summary_translate returns ("", 0) for methods that produced nothing.
        orig, best_score = summary_translate(text, org_combined_translation, target_language, session_id)
        orig_output = f"{orig}\n\nScore: {best_score:.2f}"
        plan2align_trans, best_score = summary_translate(text, p2a_combined_translation, target_language, session_id)
        plan2align_output = f"{plan2align_trans}\n\nScore: {best_score:.2f}"
        best_candidate, best_score = summary_translate(text, bfn_combined_translation, target_language, session_id)
        best_of_n_output = f"{best_candidate}\n\nScore: {best_score:.2f}"
        mpc_candidate, best_score = summary_translate(text, mpc_combined_translation, target_language, session_id)
        mpc_output = f"{mpc_candidate}\n\nScore: {best_score:.2f}"

    return orig_output, plan2align_output, best_of_n_output, mpc_output

# ---------- Gradio ----------
# Language choices offered in the two dropdowns (identical sets).
target_languages = ["Traditional Chinese", "Simplified Chinese", "English", "Russian", "German", "Japanese", "Korean"]
src_languages = ["Traditional Chinese", "Simplified Chinese", "English", "Russian", "German", "Japanese", "Korean"]

with gr.Blocks(title="Test-Time Machine Translation with Plan2Align") as demo:
    # Per-browser-session dict; get_user_session() stores the session id here.
    state = gr.State({})

    gr.Markdown("# Translation Demo: Multiple Translation Methods")
    gr.Markdown("請選擇要執行的翻譯方法(可多選或全選):")

    with gr.Row():
        # Left column: inputs and tuning knobs.
        with gr.Column(scale=1):
            source_text = gr.Textbox(
                label="Source Text",
                placeholder="請輸入文本...",
                lines=5
            )
            src_language_input = gr.Dropdown(
                choices=src_languages,
                value="Traditional Chinese",
                label="Source Language"
            )
            task_language_input = gr.Dropdown(
                choices=target_languages,
                value="English",
                label="Target Language"
            )
            # Doubles as N for Best-of-N and as iteration count for MPC.
            max_iterations_input = gr.Number(label="Max Iterations", value=6)
            threshold_input = gr.Number(label="Threshold", value=0.7)
            good_ref_contexts_num_input = gr.Number(label="Good Ref Contexts Num", value=5)
            translation_methods_input = gr.CheckboxGroup(
                choices=["Original", "Plan2Align", "Best-of-N", "MPC"],
                value=["Original", "Plan2Align"],
                label="Translation Methods"
            )
            # -1 means "translate the whole text at once" (no chunking).
            chunk_size_input = gr.Number(  # ✅ add chunk function
                label="Chunk Size (-1 for all)",
                value=-1
            )
            translate_button = gr.Button("Translate")
        # Right column: one read-only output box per method.
        with gr.Column(scale=2):
            original_output = gr.Textbox(
                label="Original Translation",
                lines=5,
                interactive=False
            )
            plan2align_output = gr.Textbox(
                label="Plan2Align Translation",
                lines=5,
                interactive=False
            )
            best_of_n_output = gr.Textbox(
                label="Best-of-N Translation",
                lines=5,
                interactive=False
            )
            mpc_output = gr.Textbox(
                label="MPC Translation",
                lines=5,
                interactive=False
            )

    # Wire the button to process_text; input order must match its signature.
    translate_button.click(
        fn=process_text,
        inputs=[
            source_text,
            src_language_input,
            task_language_input,
            max_iterations_input,
            threshold_input,
            good_ref_contexts_num_input,
            translation_methods_input,
            chunk_size_input,   # ✅ add chunk function
            state
        ],
        outputs=[original_output, plan2align_output, best_of_n_output, mpc_output]
    )

    # Clickable example inputs (Traditional Chinese source sentences).
    gr.Examples(
        examples=[
            ["台灣夜市文化豐富多彩,從士林夜市到饒河街夜市,提供各種美食、遊戲和購物體驗,吸引了無數遊客。", "Traditional Chinese", "English", 2, 0.7, 1, ["Original", "Plan2Align"], -1],
            ["台北101曾經是世界最高的建築物,它不僅是台灣的地標,也象徵著經濟成就和創新精神。", "Traditional Chinese", "Japanese", 2, 0.7, 1, ["Original", "Plan2Align"], -1],
            ["阿里山日出和森林鐵路是台灣最著名的自然景觀之一,每年吸引數十萬遊客前來欣賞雲海和壯麗的日出。", "Traditional Chinese", "Korean", 2, 0.7, 1, ["Original", "Plan2Align"], -1],
            # ["珍珠奶茶,這款源自台灣的獨特飲品,不僅在台灣本地深受喜愛,更以其獨特的風味和口感,在全球掀起了一股熱潮,成為了一種跨越文化、風靡全球的時尚飲品。", "Traditional Chinese", "Japanese", 3, 0.7, 3, ["Original", "Plan2Align"], -1],
            # ["原住民文化如同一片深邃的星空,閃爍著無數璀璨的傳統與藝術光芒。他們的歌舞,是與祖靈對話的旋律,是與自然共鳴的節奏,每一個舞步、每一聲吟唱,都承載著古老的傳說與智慧。編織,是他們巧手下的詩篇,一絲一線,交織出生命的紋理,也編織出對土地的熱愛與敬畏。木雕,則是他們與自然對話的雕塑,每一刀、每一鑿,都刻畫著對萬物的觀察與敬意,也雕琢出對祖先的追憶與傳承。", "Traditional Chinese", "Korean", 5, 0.7, 5, ["Original", "Plan2Align"], -1]
        ],
        inputs=[
            source_text,
            src_language_input,
            task_language_input,
            max_iterations_input,
            threshold_input,
            good_ref_contexts_num_input,
            translation_methods_input,
            chunk_size_input  # ✅ add chunk function
        ],
        outputs=[original_output, plan2align_output, best_of_n_output, mpc_output],
        fn=process_text
    )

    gr.Markdown("## How It Works")
    gr.Markdown("""
    1. **Original Translation:** 利用固定提示生成候選,直接取首個候選作為原始翻譯。
    2. **Plan2Align Translation:** 採用 context alignment 和 self-rewriting 策略進行翻譯,適合長文翻譯。
    3. **Best-of-N Translation:** 重複生成多次候選,評分選出最佳翻譯,適合短文翻譯。
    4. **MPC Translation:** 以迭代改善策略,每輪生成候選後評分,並將最佳翻譯作為下一輪輸入,適合短文翻譯。
    
    若輸入文本超過 1024 tokens,Best-of-N 與 MPC 方法會回傳警告訊息。
    """)

if __name__ == "__main__":
    # share=True exposes a public Gradio tunnel URL in addition to localhost.
    demo.launch(share=True)