tianyaogavin's picture
ct2 translator
87b8a8a
"""
Translation module - multilingual translation using an NLLB model accelerated with CTranslate2.
"""
from ctranslate2 import Translator
from transformers import AutoTokenizer
from langdetect import detect
import torch
import time
import logging
# Logging configuration
def setup_logger(name, level=logging.INFO):
    """Return the logger called *name*, configured with a single stream handler.

    Any handlers already attached to that logger are discarded first, so
    calling this function repeatedly for the same name never produces
    duplicated output. Propagation to ancestor loggers is switched off.

    Args:
        name: Name passed to ``logging.getLogger``.
        level: Threshold set on the logger (defaults to ``logging.INFO``).

    Returns:
        The configured ``logging.Logger`` instance.
    """
    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(
        logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    )

    log = logging.getLogger(name)
    # Clearing an empty handler list is a no-op, so no emptiness check is needed.
    log.handlers.clear()
    log.addHandler(stream_handler)
    log.setLevel(level)
    log.propagate = False
    return log


# Module-wide logger used by the translator code below.
logger = setup_logger("translator")
class NLLBTranslator:
    """Translate text with an NLLB model executed through CTranslate2.

    The Hugging Face tokenizer for ``facebook/nllb-200-distilled-600M`` is
    used to turn text into tokens and back; the CTranslate2 ``Translator``
    performs the actual generation.
    """

    # Control tokens that must never appear in the text returned to callers.
    _CONTROL_TOKENS = frozenset(("<pad>", "</s>", "<s>"))

    def __init__(self, model_dir="nllb-600m-ct2-int8-fp16", default_target="eng_Latn"):
        """Load the tokenizer and the converted translation model.

        Args:
            model_dir: Directory containing the CTranslate2-converted model.
            default_target: Language code used by ``translate`` when no
                target language is supplied.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.debug(f"使用设备: {self.device}")
        self.tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
        # "int8_float16" needs a GPU; explicitly requesting it on CPU is
        # rejected by CTranslate2, so use plain int8 when running on CPU.
        compute_type = "int8_float16" if self.device == "cuda" else "int8"
        self.translator = Translator(model_dir, device=self.device, compute_type=compute_type)
        self.default_target = default_target

    def translate(self, text: str, source_lang_code: str, target_lang_code: str = None) -> str:
        """Translate *text* from the source language into the target language.

        Args:
            text: Text to translate.
            source_lang_code: Language code of *text* (e.g. ``"zho_Hans"``).
            target_lang_code: Language code to translate into; when omitted
                or empty, ``self.default_target`` is used.

        Returns:
            The translated text with the target-language prefix and control
            tokens removed. Empty or whitespace-only input yields ``""``
            without invoking the model.
        """
        logger.debug("开始翻译")
        logger.info(f"[翻译原文] {text}")

        src_lang = source_lang_code
        tgt_lang = target_lang_code or self.default_target
        logger.debug(f"源语言: {src_lang}, 目标语言: {tgt_lang}")

        # Nothing to translate: skip the model call entirely.
        if not text or not text.strip():
            result = ""
            logger.info(f"[翻译结果] {result}")
            return result

        # Tell the tokenizer which language the input is in, so the encoded
        # sequence carries the caller's source-language tag instead of the
        # tokenizer's default one.
        if src_lang:
            self.tokenizer.src_lang = src_lang
        source = self.tokenizer.convert_ids_to_tokens(self.tokenizer.encode(text))

        start = time.time()
        results = self.translator.translate_batch(
            [source],
            length_penalty=1.2,
            # Force decoding to start with the target-language token.
            target_prefix=[[tgt_lang]],
        )
        duration = time.time() - start

        output_tokens = results[0].hypotheses[0]
        logger.debug(f"输出分词: {output_tokens}")

        # The forced target prefix is echoed back as the first output token.
        # Drop it at token level so this works for every target language and
        # never touches legitimate text that merely resembles a language code.
        if output_tokens and output_tokens[0] == tgt_lang:
            output_tokens = output_tokens[1:]
        text_tokens = [tok for tok in output_tokens if tok not in self._CONTROL_TOKENS]
        result = self.tokenizer.convert_tokens_to_string(text_tokens).strip()

        logger.debug(f"翻译完成: {src_lang} -> {tgt_lang}, 耗时: {duration * 1000:.2f}ms")
        logger.info(f"[翻译结果] {result}")
        return result
if __name__ == "__main__":
    # Manual smoke test: translate a fixed set of sentences in both
    # directions and log each result together with its wall-clock time.
    logger.setLevel(logging.DEBUG)
    translator = NLLBTranslator()

    # Chinese -> English
    zh_sentences = (
        "请问这附近有地铁站吗?",
        "我们今天要讨论人工智能的发展趋势。",
        "他的回答令人非常失望。",
        "这个项目已经进行了三个月,还需要更多资源支持。",
        "天气预报说明天会有暴雨,请大家注意安全。",
        "是时候重新思考我们的计划了。",
        "我对这个结果非常满意,感谢你的努力。",
        "她穿着一件红色的连衣裙,在人群中格外显眼。",
    )
    # English -> Chinese
    en_sentences = (
        "Can you help me find the nearest bus station?",
        "The machine learning model achieved an accuracy of 95%.",
        "He was overwhelmed by the unexpected response from the audience.",
        "It’s important to stay hydrated during hot summer days.",
        "Although she was tired, she continued working late into the night.",
        "The concert was amazing, and the crowd was full of energy.",
        "Please make sure to submit your application before the deadline.",
        "After months of preparation, the product was finally launched.",
    )

    test_cases = [(sentence, "zho_Hans", "eng_Latn") for sentence in zh_sentences]
    test_cases += [(sentence, "eng_Latn", "zho_Hans") for sentence in en_sentences]

    for case_no, (text, src_lang, tgt_lang) in enumerate(test_cases, start=1):
        logger.info(f"\n==== 测试用例 {case_no} ====")
        start_total = time.time()
        result = translator.translate(text, source_lang_code=src_lang, target_lang_code=tgt_lang)
        elapsed_ms = (time.time() - start_total) * 1000
        logger.info(f"最终翻译结果: {result}")
        logger.info(f"总耗时: {elapsed_ms:.2f}ms")