File size: 3,094 Bytes
0fe65c0
fe9a8c3
 
 
0fe65c0
fe9a8c3
 
 
 
 
 
0fe65c0
fe9a8c3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import gradio as gr
from modeling_llama_seq2seq import LlamaCrossAttentionEncDec
from transformers import AutoTokenizer, AutoConfig
from PIL import Image

# Model path and initialisation
def load_model(model_name_or_path):
    """Load the tokenizer and the seq2seq model for ``model_name_or_path``.

    Args:
        model_name_or_path: Value forwarded unchanged to every
            ``from_pretrained`` call below.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    tok = AutoTokenizer.from_pretrained(model_name_or_path)
    # The config is loaded on its own (remote code allowed) and then handed
    # explicitly to the model loader.
    model_config = AutoConfig.from_pretrained(
        model_name_or_path,
        trust_remote_code=True,
    )
    return tok, LlamaCrossAttentionEncDec.from_pretrained(
        model_name_or_path,
        config=model_config,
    )

# Loaded once at import time; translate_text reads these module-level names.
# NOTE(review): 'path_to_your_model' looks like a placeholder — confirm it is
# replaced with a real model path before running.
tokenizer, model = load_model('path_to_your_model')

# Translation function
def translate_text(src_text, src_lang, tgt_lang, task_type, term_text=None, mt_text=None):
    """Build a task-specific prompt and return the model's first decoded output.

    Args:
        src_text: Source text to translate.
        src_lang: Source language name interpolated into the prompt.
        tgt_lang: Target language name interpolated into the prompt.
        task_type: One of "常规翻译" (plain translation), "术语受限翻译"
            (terminology-constrained translation) or "自动后期编辑"
            (automatic post-editing).
        term_text: Terminology pairs; required for "术语受限翻译".
        mt_text: Machine translation to improve; required for "自动后期编辑".

    Returns:
        The first decoded generation as a string, or a user-facing message
        when the request is incomplete or the task type is unknown.
    """
    def _blank(value):
        # UI fields may arrive as None (nothing selected) or as an
        # empty / whitespace-only string.
        return not value or not str(value).strip()

    # Checked first so that every unknown task type still yields the same
    # message it always did, regardless of the other arguments.
    if task_type not in ("常规翻译", "术语受限翻译", "自动后期编辑"):
        return "请选择正确的任务类型"
    if _blank(src_text):
        return "请输入需要翻译的文本"
    if _blank(src_lang) or _blank(tgt_lang):
        # Otherwise the prompt would literally read "from None into None".
        return "请选择源语言和目标语言"

    if task_type == "常规翻译":
        prompt = f"Translate the following text from {src_lang} into {tgt_lang}.\n{src_lang}: {src_text}\n{tgt_lang}: "
    elif task_type == "术语受限翻译":
        if _blank(term_text):
            return "请提供术语表"
        prompt = f"Translate the following text from {src_lang} into {tgt_lang} using the provided terminology pairs.\nTerminology pairs: {term_text}\n{src_lang}: {src_text}\n{tgt_lang}: "
    else:  # "自动后期编辑"
        if _blank(mt_text):
            return "请提供机器翻译结果"
        prompt = f"Improve the following machine-generated translation from {src_lang} to {tgt_lang}.\n{src_lang}: {src_text}\n{tgt_lang}: {mt_text}\n{tgt_lang}: "

    # The tokenizer returns a mapping of tensors that is unpacked into generate().
    encoded = tokenizer(prompt, return_tensors="pt")
    generated = model.generate(**encoded, num_beams=5, do_sample=False)
    decoded = tokenizer.batch_decode(generated, skip_special_tokens=True)
    return decoded[0]

# Build the Gradio UI
def create_interface():
    """Build the Gradio Blocks UI for the translation demo and launch it.

    The layout has a logo, a source-text box, source/target language
    dropdowns, a task-type selector with two task-specific text boxes that
    start hidden, an output box, and a button wired to ``translate_text``.
    The function ends by calling ``demo.launch`` on 0.0.0.0:8080 with
    ``share=True``.
    """
    logo_image = Image.open('path_to_your_logo/logo.png')

    with gr.Blocks() as demo:
        gr.Image(logo_image, elem_id="logo", label="Logo")
        gr.Markdown("## 🌎 AI 翻译助手")

        with gr.Row():
            src_text = gr.Textbox(label="输入文本", placeholder="请输入需要翻译的文本")

        with gr.Row():
            src_lang = gr.Dropdown(["English", "Chinese", "French", "German"], label="源语言")
            tgt_lang = gr.Dropdown(["Chinese", "English", "French", "German"], label="目标语言")

        with gr.Row():
            task_type = gr.Radio(["常规翻译", "术语受限翻译", "自动后期编辑"], label="任务类型")
            term_text = gr.Textbox(label="术语表(术语受限翻译)", visible=False)
            mt_text = gr.Textbox(label="机器翻译结果(自动后期编辑)", visible=False)

        def show_extra_fields(task_type):
            # A bare bool returned to a Textbox output would be treated as the
            # textbox's new *value*; gr.update(visible=...) changes the
            # component's visibility instead.
            return (
                gr.update(visible=task_type == "术语受限翻译"),
                gr.update(visible=task_type == "自动后期编辑"),
            )

        task_type.change(show_extra_fields, inputs=[task_type], outputs=[term_text, mt_text])
        output_text = gr.Textbox(label="翻译结果")

        translate_button = gr.Button("翻译")
        translate_button.click(translate_text, inputs=[src_text, src_lang, tgt_lang, task_type, term_text, mt_text], outputs=output_text)

    # NOTE(review): binds on all interfaces and also requests a public share
    # link — confirm this exposure is intended for the deployment.
    demo.launch(server_name="0.0.0.0", server_port=8080, share=True)

# Build and launch the UI only when this file is executed as a script.
if __name__ == "__main__":
    create_interface()