File size: 3,094 Bytes
0fe65c0 fe9a8c3 0fe65c0 fe9a8c3 0fe65c0 fe9a8c3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 |
import gradio as gr
from modeling_llama_seq2seq import LlamaCrossAttentionEncDec
from transformers import AutoTokenizer, AutoConfig
from PIL import Image
# Model path and initialization
def load_model(model_name_or_path):
    """Load the tokenizer and the encoder-decoder model from one checkpoint.

    Args:
        model_name_or_path: Checkpoint identifier or path, forwarded unchanged
            to every ``from_pretrained`` call below.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    tok = AutoTokenizer.from_pretrained(model_name_or_path)
    # The config is loaded with trust_remote_code=True and handed explicitly
    # to the model class instead of letting the model resolve it itself.
    seq2seq = LlamaCrossAttentionEncDec.from_pretrained(
        model_name_or_path,
        config=AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True),
    )
    return tok, seq2seq
# Loaded once at import time; translate_text reads these two module globals.
tokenizer, model = load_model(model_name_or_path="path_to_your_model")
# Translation function
def _is_blank(value):
    """Return True when *value* is None or only whitespace once stringified."""
    return value is None or not str(value).strip()


def translate_text(src_text, src_lang, tgt_lang, task_type, term_text=None, mt_text=None):
    """Build a task-specific prompt and return the model's first decoded output.

    Three task types are supported, selected by ``task_type``:

    * ``"常规翻译"`` - plain translation of ``src_text``.
    * ``"术语受限翻译"`` - translation constrained by ``term_text``.
    * ``"自动后期编辑"`` - improvement of an existing translation ``mt_text``.

    Inputs are validated before any model work. A missing value yields a
    user-facing message instead of being formatted into the prompt (an unset
    optional field would otherwise appear as the literal text "None").

    Args:
        src_text: Text to translate or post-edit.
        src_lang: Source language name as it should appear in the prompt.
        tgt_lang: Target language name as it should appear in the prompt.
        task_type: One of the three task labels listed above.
        term_text: Terminology pairs; required for ``"术语受限翻译"``.
        mt_text: Machine translation to improve; required for ``"自动后期编辑"``.

    Returns:
        The decoded text of the first generated sequence, or a message string
        describing which input is missing or invalid.
    """
    if task_type not in ("常规翻译", "术语受限翻译", "自动后期编辑"):
        return "请选择正确的任务类型"
    if _is_blank(src_text):
        return "请输入需要翻译的文本"
    if _is_blank(src_lang) or _is_blank(tgt_lang):
        return "请选择源语言和目标语言"

    if task_type == "常规翻译":
        prompt = (
            f"Translate the following text from {src_lang} into {tgt_lang}.\n"
            f"{src_lang}: {src_text}\n"
            f"{tgt_lang}: "
        )
    elif task_type == "术语受限翻译":
        if _is_blank(term_text):
            return "请输入术语表"
        prompt = (
            f"Translate the following text from {src_lang} into {tgt_lang} "
            f"using the provided terminology pairs.\n"
            f"Terminology pairs: {term_text}\n"
            f"{src_lang}: {src_text}\n"
            f"{tgt_lang}: "
        )
    else:  # "自动后期编辑"
        if _is_blank(mt_text):
            return "请输入机器翻译结果"
        prompt = (
            f"Improve the following machine-generated translation from {src_lang} to {tgt_lang}.\n"
            f"{src_lang}: {src_text}\n"
            f"{tgt_lang}: {mt_text}\n"
            f"{tgt_lang}: "
        )

    # The tokenizer returns a mapping of tensors, unpacked as keyword
    # arguments to generate().
    encoded = tokenizer(prompt, return_tensors="pt")
    # Beam search with 5 beams and sampling disabled.
    generated = model.generate(**encoded, num_beams=5, do_sample=False)
    decoded = tokenizer.batch_decode(generated, skip_special_tokens=True)
    return decoded[0]
# Build the Gradio UI
def create_interface():
    """Assemble the translation UI and start the Gradio server.

    The layout holds a logo, the source text box, source/target language
    dropdowns, a task selector, two optional inputs (terminology and machine
    translation) that start hidden, the output box and the translate button.
    The button calls ``translate_text`` with all six inputs.

    The call to ``demo.launch`` binds to 0.0.0.0:8080 with ``share=True``.
    """
    logo_image = Image.open('path_to_your_logo/logo.png')

    with gr.Blocks() as demo:
        gr.Image(logo_image, elem_id="logo", label="Logo")
        gr.Markdown("## 🌎 AI 翻译助手")

        with gr.Row():
            src_text = gr.Textbox(label="输入文本", placeholder="请输入需要翻译的文本")

        with gr.Row():
            src_lang = gr.Dropdown(["English", "Chinese", "French", "German"], label="源语言")
            tgt_lang = gr.Dropdown(["Chinese", "English", "French", "German"], label="目标语言")

        with gr.Row():
            task_type = gr.Radio(["常规翻译", "术语受限翻译", "自动后期编辑"], label="任务类型")

        # Hidden until the matching task type is selected.
        term_text = gr.Textbox(label="术语表（术语受限翻译）", visible=False)
        mt_text = gr.Textbox(label="机器翻译结果（自动后期编辑）", visible=False)

        def show_extra_fields(selected_task):
            """Return visibility updates for the terminology and MT text boxes."""
            # A bare bool returned to a Textbox output would be written into
            # the box as its value; gr.update(visible=...) changes only the
            # visibility property.
            return (
                gr.update(visible=selected_task == "术语受限翻译"),
                gr.update(visible=selected_task == "自动后期编辑"),
            )

        task_type.change(show_extra_fields, inputs=[task_type], outputs=[term_text, mt_text])

        output_text = gr.Textbox(label="翻译结果")
        translate_button = gr.Button("翻译")
        translate_button.click(
            translate_text,
            inputs=[src_text, src_lang, tgt_lang, task_type, term_text, mt_text],
            outputs=output_text,
        )

    # NOTE(review): share=True together with 0.0.0.0 exposes the app beyond
    # localhost - confirm this is intended for the deployment.
    demo.launch(server_name="0.0.0.0", server_port=8080, share=True)
if __name__ == "__main__":
    # Build and launch the UI only when this file is executed as a script.
    create_interface()
|