import gradio as gr
from modeling_llama_seq2seq import LlamaCrossAttentionEncDec
from transformers import AutoTokenizer, AutoConfig
from PIL import Image

# Task labels shown in the UI radio button; translate_text dispatches on these
# exact strings.
TASK_GENERAL = "常规翻译"
TASK_TERMINOLOGY = "术语受限翻译"
TASK_POST_EDIT = "自动后期编辑"


# Model path and initialization
def load_model(model_name_or_path):
    """Load the tokenizer and the LlamaCrossAttentionEncDec model from one checkpoint.

    Args:
        model_name_or_path: Checkpoint directory or hub id passed to every
            ``from_pretrained`` call.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
    # trust_remote_code=True is passed, presumably because the checkpoint ships
    # a custom config class for this architecture — confirm.
    config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
    model = LlamaCrossAttentionEncDec.from_pretrained(model_name_or_path, config=config)
    return tokenizer, model


# NOTE: the model is loaded at import time; replace the placeholder path
# before running.
tokenizer, model = load_model('path_to_your_model')


def _build_prompt(src_text, src_lang, tgt_lang, task_type, term_text=None, mt_text=None):
    """Return the instruction prompt for *task_type*, or None for an unknown task."""
    if task_type == TASK_GENERAL:
        return f"Translate the following text from {src_lang} into {tgt_lang}.\n{src_lang}: {src_text}\n{tgt_lang}: "
    if task_type == TASK_TERMINOLOGY:
        return f"Translate the following text from {src_lang} into {tgt_lang} using the provided terminology pairs.\nTerminology pairs: {term_text}\n{src_lang}: {src_text}\n{tgt_lang}: "
    if task_type == TASK_POST_EDIT:
        return f"Improve the following machine-generated translation from {src_lang} to {tgt_lang}.\n{src_lang}: {src_text}\n{tgt_lang}: {mt_text}\n{tgt_lang}: "
    return None


def _is_blank(text):
    """Return True when *text* is None or contains only whitespace."""
    return text is None or not str(text).strip()


# Translation function
def translate_text(src_text, src_lang, tgt_lang, task_type, term_text=None, mt_text=None,
                   max_new_tokens=512, num_beams=5):
    """Run one translation / terminology / post-editing request through the model.

    Args:
        src_text: Text to translate.
        src_lang: Source language name inserted into the prompt (e.g. "English").
        tgt_lang: Target language name inserted into the prompt.
        task_type: One of TASK_GENERAL, TASK_TERMINOLOGY, TASK_POST_EDIT.
        term_text: Terminology pairs; required for TASK_TERMINOLOGY.
        mt_text: Existing machine translation; required for TASK_POST_EDIT.
        max_new_tokens: Upper bound on the number of generated tokens.
        num_beams: Beam width for beam search (sampling is disabled).

    Returns:
        The first decoded generated sequence, or a user-facing (Chinese)
        message when the request is incomplete or the task type is unknown.
    """
    prompt = _build_prompt(src_text, src_lang, tgt_lang, task_type, term_text, mt_text)
    if prompt is None:
        return "请选择正确的任务类型"
    # Reject incomplete requests instead of prompting the model with empty
    # text or the literal string "None".
    if _is_blank(src_text):
        return "请输入需要翻译的文本"
    if _is_blank(src_lang) or _is_blank(tgt_lang):
        return "请选择源语言和目标语言"
    if task_type == TASK_TERMINOLOGY and _is_blank(term_text):
        return "请输入术语表"
    if task_type == TASK_POST_EDIT and _is_blank(mt_text):
        return "请输入机器翻译结果"

    # The tokenizer returns a BatchEncoding (input_ids and, typically,
    # attention_mask) that is unpacked straight into generate().
    inputs = tokenizer(prompt, return_tensors="pt")
    # Give generate() an explicit length budget: otherwise it falls back to the
    # generation config's max_length (20 tokens in transformers' defaults),
    # which silently truncates most translations.
    output_ids = model.generate(
        **inputs,
        num_beams=num_beams,
        do_sample=False,
        max_new_tokens=max_new_tokens,
    )
    outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
    return outputs[0]


# Build the Gradio UI
def create_interface(logo_path='path_to_your_logo/logo.png', server_name="0.0.0.0",
                     server_port=8080, share=True):
    """Build the translation web UI and launch the Gradio server.

    Args:
        logo_path: Image file displayed at the top of the page.
        server_name: Host/interface passed to ``demo.launch``.
        server_port: Port passed to ``demo.launch``.
        share: Passed to ``demo.launch``; True requests a public share link.
    """
    # Image.open is lazy and keeps the file open; copy inside a `with` so the
    # handle is released immediately.
    with Image.open(logo_path) as img:
        logo_image = img.copy()

    with gr.Blocks() as demo:
        gr.Image(logo_image, elem_id="logo", label="Logo")
        gr.Markdown("## 🌎 AI 翻译助手")
        with gr.Row():
            src_text = gr.Textbox(label="输入文本", placeholder="请输入需要翻译的文本")
        with gr.Row():
            src_lang = gr.Dropdown(["English", "Chinese", "French", "German"], value="English", label="源语言")
            tgt_lang = gr.Dropdown(["Chinese", "English", "French", "German"], value="Chinese", label="目标语言")
        with gr.Row():
            task_type = gr.Radio([TASK_GENERAL, TASK_TERMINOLOGY, TASK_POST_EDIT], value=TASK_GENERAL, label="任务类型")
        # Extra inputs, shown only for the task that needs them.
        term_text = gr.Textbox(label="术语表(术语受限翻译)", visible=False)
        mt_text = gr.Textbox(label="机器翻译结果(自动后期编辑)", visible=False)

        def show_extra_fields(selected):
            # Returning bare booleans would set the textboxes' *values* (to
            # "True"/"False") rather than their visibility; gr.update(visible=...)
            # toggles visibility instead.
            return (
                gr.update(visible=selected == TASK_TERMINOLOGY),
                gr.update(visible=selected == TASK_POST_EDIT),
            )

        task_type.change(show_extra_fields, inputs=[task_type], outputs=[term_text, mt_text])

        output_text = gr.Textbox(label="翻译结果")
        translate_button = gr.Button("翻译")
        translate_button.click(
            translate_text,
            inputs=[src_text, src_lang, tgt_lang, task_type, term_text, mt_text],
            outputs=output_text,
        )

    # SECURITY: with the defaults, share=True publishes a public Gradio link and
    # server_name="0.0.0.0" listens on every network interface. Pass
    # share=False / server_name="127.0.0.1" for local-only use.
    demo.launch(server_name=server_name, server_port=server_port, share=share)


if __name__ == "__main__":
    create_interface()