Tbb1111 commited on
Commit
6928084
·
verified ·
1 Parent(s): db3dca1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -10
app.py CHANGED
@@ -1,16 +1,16 @@
1
  import gradio as gr
2
- from transformers import T5ForConditionalGeneration, T5Tokenizer
3
 
4
- # 加载 T5 模型和分词器
5
- model_name = "t5-small" # 可以根据需要调整模型大小
6
- model = T5ForConditionalGeneration.from_pretrained(model_name)
7
- tokenizer = T5Tokenizer.from_pretrained(model_name)
8
 
9
  # 翻译功能
10
  def translate_text(input_text):
11
- # 使用 T5 模型进行翻译
12
- inputs = tokenizer.encode("翻译成中文: " + input_text, return_tensors="pt", max_length=512, truncation=True)
13
- outputs = model.generate(inputs, max_length=1024, num_beams=4, early_stopping=True)
14
  translated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
15
 
16
  return translated_text
@@ -20,10 +20,10 @@ with gr.Blocks() as demo:
20
  gr.Markdown("# 英文文本翻译器")
21
 
22
  with gr.Row():
23
- text_input = gr.Textbox(label="输入英文文本", lines=5) # 让用户输入英文文本
24
 
25
  translate_button = gr.Button("开始翻译")
26
- output_text = gr.Textbox(label="翻译后的中文文本", lines=5) # 显示翻译后的中文文本
27
 
28
  translate_button.click(fn=translate_text, inputs=text_input, outputs=output_text)
29
 
 
1
  import gradio as gr
2
+ from transformers import MarianMTModel, MarianTokenizer
3
 
4
+ # 加载 MarianMT 模型和分词器
5
+ model_name = "Helsinki-NLP/opus-mt-en-zh"
6
+ model = MarianMTModel.from_pretrained(model_name).to("cpu") # 强制使用 CPU
7
+ tokenizer = MarianTokenizer.from_pretrained(model_name)
8
 
9
  # 翻译功能
10
  def translate_text(input_text):
11
+ # 使用 MarianMT 模型进行翻译
12
+ inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True)
13
+ outputs = model.generate(**inputs)
14
  translated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
15
 
16
  return translated_text
 
20
  gr.Markdown("# 英文文本翻译器")
21
 
22
  with gr.Row():
23
+ text_input = gr.Textbox(label="输入英文文本", lines=5)
24
 
25
  translate_button = gr.Button("开始翻译")
26
+ output_text = gr.Textbox(label="翻译后的中文文本", lines=5)
27
 
28
  translate_button.click(fn=translate_text, inputs=text_input, outputs=output_text)
29