lx-Meteor commited on
Commit
fe9a8c3
·
1 Parent(s): 0fe65c0

Add application file

Browse files
Files changed (1) hide show
  1. app.py +60 -4
app.py CHANGED
@@ -1,7 +1,63 @@
1
  import gradio as gr
 
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
 
 
 
 
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ from modeling_llama_seq2seq import LlamaCrossAttentionEncDec
3
+ from transformers import AutoTokenizer, AutoConfig
4
+ from PIL import Image
5
 
6
+ # 模型路径和初始化
7
+ def load_model(model_name_or_path):
8
+ tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
9
+ config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
10
+ model = LlamaCrossAttentionEncDec.from_pretrained(model_name_or_path, config=config)
11
+ return tokenizer, model
12
 
13
+ tokenizer, model = load_model('path_to_your_model')
14
+
15
+ # 翻译函数
16
+ def translate_text(src_text, src_lang, tgt_lang, task_type, term_text=None, mt_text=None):
17
+ if task_type == "常规翻译":
18
+ prompt = f"Translate the following text from {src_lang} into {tgt_lang}.\n{src_lang}: {src_text}\n{tgt_lang}: "
19
+ elif task_type == "术语受限翻译":
20
+ prompt = f"Translate the following text from {src_lang} into {tgt_lang} using the provided terminology pairs.\nTerminology pairs: {term_text}\n{src_lang}: {src_text}\n{tgt_lang}: "
21
+ elif task_type == "自动后期编辑":
22
+ prompt = f"Improve the following machine-generated translation from {src_lang} to {tgt_lang}.\n{src_lang}: {src_text}\n{tgt_lang}: {mt_text}\n{tgt_lang}: "
23
+ else:
24
+ return "请选择正确的任务类型"
25
+
26
+ input_ids = tokenizer(prompt, return_tensors="pt")
27
+ outputs_tokenized = model.generate(**input_ids, num_beams=5, do_sample=False)
28
+ outputs = tokenizer.batch_decode(outputs_tokenized, skip_special_tokens=True)
29
+ return outputs[0]
30
+
31
+ # 构建 Gradio UI
32
+ def create_interface():
33
+ logo_image = Image.open('path_to_your_logo/logo.png')
34
+
35
+ with gr.Blocks() as demo:
36
+ gr.Image(logo_image, elem_id="logo", label="Logo")
37
+ gr.Markdown("## 🌎 AI 翻译助手")
38
+
39
+ with gr.Row():
40
+ src_text = gr.Textbox(label="输入文本", placeholder="请输入需要翻译的文本")
41
+
42
+ with gr.Row():
43
+ src_lang = gr.Dropdown(["English", "Chinese", "French", "German"], label="源语言")
44
+ tgt_lang = gr.Dropdown(["Chinese", "English", "French", "German"], label="目标语言")
45
+
46
+ with gr.Row():
47
+ task_type = gr.Radio(["常规翻译", "术语受限翻译", "自动后期编辑"], label="任务类型")
48
+ term_text = gr.Textbox(label="术语表(术语受限翻译)", visible=False)
49
+ mt_text = gr.Textbox(label="机器翻译结果(自动后期编辑)", visible=False)
50
+
51
+ def show_extra_fields(task_type):
52
+ return (task_type == "术语受限翻译", task_type == "自动后期编辑")
53
+
54
+ task_type.change(show_extra_fields, inputs=[task_type], outputs=[term_text, mt_text])
55
+ output_text = gr.Textbox(label="翻译结果")
56
+
57
+ translate_button = gr.Button("翻译")
58
+ translate_button.click(translate_text, inputs=[src_text, src_lang, tgt_lang, task_type, term_text, mt_text], outputs=output_text)
59
+
60
+ demo.launch(server_name="0.0.0.0", server_port=8080, share=True)
61
+
62
+ if __name__ == "__main__":
63
+ create_interface()