File size: 4,697 Bytes
87b8a8a f62e484 87b8a8a f62e484 1bf36cc f62e484 87b8a8a 1bf36cc f62e484 87b8a8a f62e484 87b8a8a 1bf36cc 87b8a8a f62e484 1bf36cc 87b8a8a 1bf36cc 87b8a8a 1bf36cc 87b8a8a 1bf36cc 87b8a8a 1bf36cc f62e484 87b8a8a f62e484 1bf36cc f62e484 87b8a8a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 |
'''
翻译模块 - 使用CTranslate2加速的NLLB模型进行多语言翻译
'''
from ctranslate2 import Translator
from transformers import AutoTokenizer
from langdetect import detect
import torch
import time
import logging
# 配置日志
def setup_logger(name, level=logging.INFO):
    """Return the logger called *name*, reset to a single stream handler.

    Handlers already attached to the logger are detached and closed first,
    so calling this function repeatedly for the same name does not stack up
    handlers. Propagation to ancestor loggers is switched off.

    Args:
        name: Name passed to ``logging.getLogger``.
        level: Threshold set on the logger; defaults to ``logging.INFO``.

    Returns:
        The configured ``logging.Logger`` instance.
    """
    log = logging.getLogger(name)

    # Iterate over a copy: removeHandler() mutates ``log.handlers``.
    for stale in list(log.handlers):
        log.removeHandler(stale)
        # Give the discarded handler a chance to release its resources.
        stale.close()

    handler = logging.StreamHandler()
    handler.setFormatter(
        logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    )
    log.addHandler(handler)
    log.setLevel(level)
    log.propagate = False
    return log


# Module-wide logger used by the translator code below.
logger = setup_logger("translator")
class NLLBTranslator:
    """Translate text with an NLLB model executed through CTranslate2.

    The Hugging Face tokenizer turns text into sub-word tokens, the
    CTranslate2 ``Translator`` generates the target tokens, and the tokenizer
    turns those back into a string.
    """

    # Special tokens removed from the detokenised output.
    _SPECIAL_TOKENS = ("<pad>", "</s>", "<s>")

    def __init__(self, model_dir="nllb-600m-ct2-int8-fp16", default_target="eng_Latn",
                 tokenizer_name="facebook/nllb-200-distilled-600M"):
        """Load the tokenizer and the CTranslate2 model.

        Args:
            model_dir: Directory of the converted CTranslate2 model.
            default_target: Language code used when ``translate`` is called
                without a target language.
            tokenizer_name: Hugging Face tokenizer identifier; should match
                the checkpoint the CTranslate2 model was converted from.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.debug(f"使用设备: {self.device}")
        # float16 computation is a GPU feature in CTranslate2; requesting
        # "int8_float16" on CPU is rejected, so use plain int8 there.
        compute_type = "int8_float16" if self.device == "cuda" else "int8"
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        self.translator = Translator(model_dir, device=self.device, compute_type=compute_type)
        self.default_target = default_target

    def translate(self, text: str, source_lang_code: str, target_lang_code: str = None) -> str:
        """Translate *text* from the source language into the target language.

        Args:
            text: Text to translate.
            source_lang_code: Language code of *text* (e.g. ``"zho_Hans"``).
            target_lang_code: Language code to translate into; falls back to
                ``self.default_target`` when omitted or empty.

        Returns:
            The translated text with special tokens and the target-language
            prefix removed. Empty or whitespace-only input yields ``""``
            without invoking the model.
        """
        # Nothing to translate: skip tokenisation and inference entirely.
        if not text or not text.strip():
            return ""

        logger.debug("开始翻译")
        logger.info(f"[翻译原文] {text}")
        src_lang = source_lang_code
        tgt_lang = target_lang_code or self.default_target
        logger.debug(f"源语言: {src_lang}, 目标语言: {tgt_lang}")

        # Tell the tokenizer which language the input is in so the encoded
        # sequence carries the right source-language tag; otherwise every
        # input is tagged with the tokenizer's default language.
        # NOTE(review): this mutates shared tokenizer state - confirm callers
        # do not invoke translate() concurrently on one instance.
        if src_lang:
            self.tokenizer.src_lang = src_lang
        source = self.tokenizer.convert_ids_to_tokens(self.tokenizer.encode(text))

        start = time.perf_counter()
        results = self.translator.translate_batch(
            [source],
            length_penalty=1.2,
            # Decoding starts from the target-language tag.
            target_prefix=[[tgt_lang]],
        )
        duration = time.perf_counter() - start

        output_tokens = results[0].hypotheses[0]
        logger.debug(f"输出分词: {output_tokens}")

        # The hypothesis begins with the prefix we supplied. Drop that one
        # token rather than deleting language codes from the decoded string,
        # which only covered a fixed list of languages and could remove
        # legitimate text.
        if output_tokens and output_tokens[0] == tgt_lang:
            output_tokens = output_tokens[1:]

        result = self.tokenizer.convert_tokens_to_string(output_tokens)
        for special in self._SPECIAL_TOKENS:
            result = result.replace(special, "")
        result = result.strip()

        logger.debug(f"翻译完成: {src_lang} -> {tgt_lang}, 耗时: {duration * 1000:.2f}ms")
        logger.info(f"[翻译结果] {result}")
        return result
if __name__ == "__main__":
    # Manual smoke test: translate a fixed set of sentences in both
    # directions and log each result together with its wall-clock time.
    logger.setLevel(logging.DEBUG)
    translator = NLLBTranslator()

    # Chinese -> English samples
    zh_samples = [
        "请问这附近有地铁站吗?",
        "我们今天要讨论人工智能的发展趋势。",
        "他的回答令人非常失望。",
        "这个项目已经进行了三个月,还需要更多资源支持。",
        "天气预报说明天会有暴雨,请大家注意安全。",
        "是时候重新思考我们的计划了。",
        "我对这个结果非常满意,感谢你的努力。",
        "她穿着一件红色的连衣裙,在人群中格外显眼。",
    ]
    # English -> Chinese samples
    en_samples = [
        "Can you help me find the nearest bus station?",
        "The machine learning model achieved an accuracy of 95%.",
        "He was overwhelmed by the unexpected response from the audience.",
        "It’s important to stay hydrated during hot summer days.",
        "Although she was tired, she continued working late into the night.",
        "The concert was amazing, and the crowd was full of energy.",
        "Please make sure to submit your application before the deadline.",
        "After months of preparation, the product was finally launched.",
    ]

    test_cases = [(sample, "zho_Hans", "eng_Latn") for sample in zh_samples]
    test_cases += [(sample, "eng_Latn", "zho_Hans") for sample in en_samples]

    for case_no, (text, src_lang, tgt_lang) in enumerate(test_cases, start=1):
        logger.info(f"\n==== 测试用例 {case_no} ====")
        started = time.time()
        result = translator.translate(text, source_lang_code=src_lang, target_lang_code=tgt_lang)
        elapsed_ms = (time.time() - started) * 1000
        logger.info(f"最终翻译结果: {result}")
        logger.info(f"总耗时: {elapsed_ms:.2f}ms")
|