Add application file
Browse files
app.py
CHANGED
@@ -1,7 +1,63 @@
|
|
1 |
import gradio as gr
|
|
|
|
|
|
|
2 |
|
3 |
-
|
4 |
-
|
|
|
|
|
|
|
|
|
5 |
|
6 |
-
|
7 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import gradio as gr
|
2 |
+
from modeling_llama_seq2seq import LlamaCrossAttentionEncDec
|
3 |
+
from transformers import AutoTokenizer, AutoConfig
|
4 |
+
from PIL import Image
|
5 |
|
6 |
+
# 模型路径和初始化
|
7 |
+
def load_model(model_name_or_path):
|
8 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
|
9 |
+
config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
|
10 |
+
model = LlamaCrossAttentionEncDec.from_pretrained(model_name_or_path, config=config)
|
11 |
+
return tokenizer, model
|
12 |
|
13 |
+
tokenizer, model = load_model('path_to_your_model')
|
14 |
+
|
15 |
+
# 翻译函数
|
16 |
+
def translate_text(src_text, src_lang, tgt_lang, task_type, term_text=None, mt_text=None):
|
17 |
+
if task_type == "常规翻译":
|
18 |
+
prompt = f"Translate the following text from {src_lang} into {tgt_lang}.\n{src_lang}: {src_text}\n{tgt_lang}: "
|
19 |
+
elif task_type == "术语受限翻译":
|
20 |
+
prompt = f"Translate the following text from {src_lang} into {tgt_lang} using the provided terminology pairs.\nTerminology pairs: {term_text}\n{src_lang}: {src_text}\n{tgt_lang}: "
|
21 |
+
elif task_type == "自动后期编辑":
|
22 |
+
prompt = f"Improve the following machine-generated translation from {src_lang} to {tgt_lang}.\n{src_lang}: {src_text}\n{tgt_lang}: {mt_text}\n{tgt_lang}: "
|
23 |
+
else:
|
24 |
+
return "请选择正确的任务类型"
|
25 |
+
|
26 |
+
input_ids = tokenizer(prompt, return_tensors="pt")
|
27 |
+
outputs_tokenized = model.generate(**input_ids, num_beams=5, do_sample=False)
|
28 |
+
outputs = tokenizer.batch_decode(outputs_tokenized, skip_special_tokens=True)
|
29 |
+
return outputs[0]
|
30 |
+
|
31 |
+
# 构建 Gradio UI
|
32 |
+
def create_interface():
|
33 |
+
logo_image = Image.open('path_to_your_logo/logo.png')
|
34 |
+
|
35 |
+
with gr.Blocks() as demo:
|
36 |
+
gr.Image(logo_image, elem_id="logo", label="Logo")
|
37 |
+
gr.Markdown("## 🌎 AI 翻译助手")
|
38 |
+
|
39 |
+
with gr.Row():
|
40 |
+
src_text = gr.Textbox(label="输入文本", placeholder="请输入需要翻译的文本")
|
41 |
+
|
42 |
+
with gr.Row():
|
43 |
+
src_lang = gr.Dropdown(["English", "Chinese", "French", "German"], label="源语言")
|
44 |
+
tgt_lang = gr.Dropdown(["Chinese", "English", "French", "German"], label="目标语言")
|
45 |
+
|
46 |
+
with gr.Row():
|
47 |
+
task_type = gr.Radio(["常规翻译", "术语受限翻译", "自动后期编辑"], label="任务类型")
|
48 |
+
term_text = gr.Textbox(label="术语表(术语受限翻译)", visible=False)
|
49 |
+
mt_text = gr.Textbox(label="机器翻译结果(自动后期编辑)", visible=False)
|
50 |
+
|
51 |
+
def show_extra_fields(task_type):
|
52 |
+
return (task_type == "术语受限翻译", task_type == "自动后期编辑")
|
53 |
+
|
54 |
+
task_type.change(show_extra_fields, inputs=[task_type], outputs=[term_text, mt_text])
|
55 |
+
output_text = gr.Textbox(label="翻译结果")
|
56 |
+
|
57 |
+
translate_button = gr.Button("翻译")
|
58 |
+
translate_button.click(translate_text, inputs=[src_text, src_lang, tgt_lang, task_type, term_text, mt_text], outputs=output_text)
|
59 |
+
|
60 |
+
demo.launch(server_name="0.0.0.0", server_port=8080, share=True)
|
61 |
+
|
62 |
+
if __name__ == "__main__":
|
63 |
+
create_interface()
|