|
import gradio as gr |
|
from modeling_llama_seq2seq import LlamaCrossAttentionEncDec |
|
from transformers import AutoTokenizer, AutoConfig |
|
from PIL import Image |
|
|
|
|
|
def load_model(model_name_or_path):
    """Load the tokenizer and the ``LlamaCrossAttentionEncDec`` model.

    Args:
        model_name_or_path: Checkpoint location, forwarded unchanged to every
            ``from_pretrained`` call below.

    Returns:
        tuple: ``(tokenizer, model)``.
    """
    path = model_name_or_path
    tok = AutoTokenizer.from_pretrained(path)
    # trust_remote_code=True lets the config load custom code bundled with the
    # checkpoint -- only point this at checkpoints you trust.
    cfg = AutoConfig.from_pretrained(path, trust_remote_code=True)
    return tok, LlamaCrossAttentionEncDec.from_pretrained(path, config=cfg)
|
|
|
# Loaded once at import time; translate_text reads these module globals on
# every request. The path below is a placeholder -- replace it with a real
# checkpoint directory (or model id) before running.
tokenizer, model = load_model(
    'path_to_your_model',
)
|
|
|
|
|
# Prompt templates keyed by the task labels offered in the UI's task selector.
# They are filled with str.format(); substituted values are not re-parsed, so
# braces inside user text are safe.
_PROMPT_TEMPLATES = {
    "常规翻译": (
        "Translate the following text from {src_lang} into {tgt_lang}.\n"
        "{src_lang}: {src_text}\n"
        "{tgt_lang}: "
    ),
    "术语受限翻译": (
        "Translate the following text from {src_lang} into {tgt_lang} "
        "using the provided terminology pairs.\n"
        "Terminology pairs: {term_text}\n"
        "{src_lang}: {src_text}\n"
        "{tgt_lang}: "
    ),
    "自动后期编辑": (
        "Improve the following machine-generated translation "
        "from {src_lang} to {tgt_lang}.\n"
        "{src_lang}: {src_text}\n"
        "{tgt_lang}: {mt_text}\n"
        "{tgt_lang}: "
    ),
}


def _is_blank(value):
    """Return True when *value* is None, empty, or whitespace-only."""
    return value is None or not str(value).strip()


def translate_text(src_text, src_lang, tgt_lang, task_type, term_text=None, mt_text=None,
                   num_beams=5, max_new_tokens=None):
    """Run the globally loaded model on *src_text* for the chosen task.

    Args:
        src_text: Text to translate.
        src_lang: Source-language name, inserted verbatim into the prompt.
        tgt_lang: Target-language name, inserted verbatim into the prompt.
        task_type: One of the UI task labels: "常规翻译" (plain translation),
            "术语受限翻译" (terminology-constrained) or "自动后期编辑"
            (automatic post-editing).
        term_text: Terminology pairs; required for "术语受限翻译".
        mt_text: Existing machine translation; required for "自动后期编辑".
        num_beams: Beam width passed to ``generate`` (default 5, as before).
        max_new_tokens: Optional cap on generated tokens. None leaves the
            model's generation-config default in effect.

    Returns:
        str: The first decoded output sequence, or a (Chinese) user-facing
        message when the inputs are incomplete or the task is unknown.
    """
    template = _PROMPT_TEMPLATES.get(task_type)
    if template is None:
        return "请选择正确的任务类型"
    # Validate before touching the model so incomplete form submissions never
    # produce prompts containing "None" or empty fields.
    if _is_blank(src_text):
        return "请输入需要翻译的文本"
    if _is_blank(src_lang) or _is_blank(tgt_lang):
        return "请选择源语言和目标语言"
    if task_type == "术语受限翻译" and _is_blank(term_text):
        return "请输入术语表"
    if task_type == "自动后期编辑" and _is_blank(mt_text):
        return "请输入机器翻译结果"

    prompt = template.format(
        src_lang=src_lang,
        tgt_lang=tgt_lang,
        src_text=src_text,
        term_text=term_text,
        mt_text=mt_text,
    )

    # The tokenizer returns a BatchEncoding (a dict of tensors such as
    # input_ids), not bare ids. Move it to wherever the model lives so that
    # generation also works when the model has been placed on a GPU.
    encoded = tokenizer(prompt, return_tensors="pt").to(model.device)

    gen_kwargs = {"num_beams": num_beams, "do_sample": False}
    if max_new_tokens is not None:
        gen_kwargs["max_new_tokens"] = max_new_tokens
    output_ids = model.generate(**encoded, **gen_kwargs)

    # A single prompt was encoded, so only the first decoded sequence matters.
    return tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0]
|
|
|
|
|
def _load_logo(logo_path):
    """Return the image at *logo_path* fully loaded into memory, or None.

    ``Image.open`` is lazy and keeps the file open. Copying the image inside a
    ``with`` block forces a full decode, so the handle is closed right away.
    A missing or unreadable logo only emits a warning because it is purely
    decorative and should not stop the app from starting.
    """
    import warnings

    try:
        with Image.open(logo_path) as img:
            return img.copy()
    except OSError as err:  # covers FileNotFoundError and UnidentifiedImageError
        warnings.warn(f"Logo not loaded from {logo_path!r}: {err}", stacklevel=2)
        return None


def create_interface(server_name="0.0.0.0", server_port=8080, share=True,
                     logo_path="path_to_your_logo/logo.png", launch=True):
    """Build the translation UI and, by default, launch it.

    Args:
        server_name: Host passed to ``demo.launch`` (default "0.0.0.0").
        server_port: Port passed to ``demo.launch`` (default 8080).
        share: Passed to ``demo.launch``. The default True keeps the previous
            behavior; Gradio uses it to create a public share link, so disable
            it when the app should stay private.
        logo_path: Image shown at the top of the page. It is skipped, with a
            warning, if it cannot be opened.
        launch: When False, return the app without starting the server.

    Returns:
        gr.Blocks: The constructed app.
    """
    logo_image = _load_logo(logo_path)

    with gr.Blocks() as demo:
        if logo_image is not None:
            # Display-only: keep users from uploading over the logo.
            gr.Image(logo_image, elem_id="logo", label="Logo", interactive=False)
        gr.Markdown("## 🌎 AI 翻译助手")

        with gr.Row():
            src_text = gr.Textbox(label="输入文本", placeholder="请输入需要翻译的文本")

        with gr.Row():
            # Defaults prevent None from reaching the prompt when the user
            # never touches a dropdown.
            src_lang = gr.Dropdown(["English", "Chinese", "French", "German"],
                                   value="English", label="源语言")
            tgt_lang = gr.Dropdown(["Chinese", "English", "French", "German"],
                                   value="Chinese", label="目标语言")

        with gr.Row():
            task_type = gr.Radio(["常规翻译", "术语受限翻译", "自动后期编辑"],
                                 value="常规翻译", label="任务类型")
            # Extra inputs that only some tasks need; hidden until selected.
            term_text = gr.Textbox(label="术语表(术语受限翻译)", visible=False)
            mt_text = gr.Textbox(label="机器翻译结果(自动后期编辑)", visible=False)

        def show_extra_fields(selected_task):
            # gr.update(visible=...) toggles visibility. A bare bool would be
            # treated as the textbox's new *value* ("True"/"False") and could
            # then leak into the prompt.
            return (
                gr.update(visible=selected_task == "术语受限翻译"),
                gr.update(visible=selected_task == "自动后期编辑"),
            )

        task_type.change(show_extra_fields, inputs=[task_type], outputs=[term_text, mt_text])

        output_text = gr.Textbox(label="翻译结果")

        translate_button = gr.Button("翻译")
        translate_button.click(
            translate_text,
            inputs=[src_text, src_lang, tgt_lang, task_type, term_text, mt_text],
            outputs=output_text,
        )

    # Launch only after the Blocks context has closed and the layout is final.
    if launch:
        demo.launch(server_name=server_name, server_port=server_port, share=share)
    return demo
|
|
|
if __name__ == "__main__":
    # Script entry point: build the Gradio UI and start the web server.
    create_interface()
|
|