|
''' |
|
翻译模块 - 使用CTranslate2加速的NLLB模型进行多语言翻译 |
|
''' |
|
|
|
from ctranslate2 import Translator |
|
from transformers import AutoTokenizer |
|
from langdetect import detect |
|
import torch |
|
import time |
|
import logging |
|
|
|
|
|
def setup_logger(name, level=logging.INFO):
    """Return the logger called *name*, configured with one stream handler.

    Handlers attached by earlier calls are discarded first, so calling this
    repeatedly for the same name never produces duplicated log lines. The
    logger does not propagate to its ancestors.

    Args:
        name: Name passed to ``logging.getLogger``.
        level: Threshold applied to the logger (defaults to ``logging.INFO``).

    Returns:
        The configured ``logging.Logger`` instance.
    """
    log = logging.getLogger(name)

    # Start from a clean slate; clearing an already-empty list is a no-op.
    log.handlers.clear()

    stream = logging.StreamHandler()
    stream.setFormatter(
        logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    )

    log.setLevel(level)
    log.propagate = False
    log.addHandler(stream)
    return log


# Module-wide logger used by the translator class and the demo script below.
logger = setup_logger("translator")
|
|
|
class NLLBTranslator:
    """Translate text with an NLLB model executed by CTranslate2.

    A Hugging Face tokenizer turns text into subword tokens and back; the
    CTranslate2 ``Translator`` runs the converted model on the GPU when CUDA
    is available and on the CPU otherwise.
    """

    def __init__(self, model_dir="nllb-600m-ct2-int8-fp16", default_target="eng_Latn",
                 tokenizer_name="facebook/nllb-200-distilled-600M"):
        """Load the tokenizer and the converted CTranslate2 model.

        Args:
            model_dir: Directory holding the converted CTranslate2 model.
            default_target: NLLB language code used by ``translate`` when no
                target language is given.
            tokenizer_name: Hugging Face id or local path of the tokenizer
                that matches the model.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.debug("使用设备: %s", self.device)

        # float16 computation is only available on GPU, and CTranslate2
        # rejects an explicitly requested compute type the device cannot run,
        # so the CPU fallback has to ask for plain int8 instead.
        compute_type = "int8_float16" if self.device == "cuda" else "int8"

        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        self.translator = Translator(model_dir, device=self.device, compute_type=compute_type)
        self.default_target = default_target

    def translate(self, text: str, source_lang_code: str, target_lang_code: str = None) -> str:
        """Translate *text* from the source language into the target language.

        Args:
            text: Text to translate. Empty or whitespace-only input yields
                ``""`` without running the model.
            source_lang_code: NLLB code of the input language (for example
                ``"zho_Hans"``). It is applied to the tokenizer so the source
                sequence carries the matching language token. A falsy value
                leaves the tokenizer's current source language untouched.
            target_lang_code: NLLB code of the output language; falls back to
                ``self.default_target`` when omitted.

        Returns:
            The translated text with the language prefix and special tokens
            removed.

        Note:
            The source language is stored on the shared tokenizer object, so
            concurrent calls on one instance can interfere with each other.
        """
        logger.debug("开始翻译")
        logger.info("[翻译原文] %s", text)

        src_lang = source_lang_code
        tgt_lang = target_lang_code or self.default_target
        logger.debug("源语言: %s, 目标语言: %s", src_lang, tgt_lang)

        # There is nothing to translate; skip the model call entirely.
        if not text or not text.strip():
            logger.info("[翻译结果] %s", "")
            return ""

        # The tokenizer adds the source-language token on its own, so it has
        # to be told which language this input is written in before encoding.
        if src_lang:
            self.tokenizer.src_lang = src_lang
        source = self.tokenizer.convert_ids_to_tokens(self.tokenizer.encode(text))

        started = time.perf_counter()
        target_prefix = [tgt_lang]
        results = self.translator.translate_batch(
            [source],
            length_penalty=1.2,
            target_prefix=[target_prefix],
        )
        duration = time.perf_counter() - started

        output_tokens = results[0].hypotheses[0]
        logger.debug("输出分词: %s", output_tokens)

        # The hypothesis starts with the forced target prefix. Drop it by
        # position so any target language works and ordinary text that merely
        # looks like a language code is left alone.
        prefix_len = len(target_prefix)
        if output_tokens[:prefix_len] == target_prefix:
            output_tokens = output_tokens[prefix_len:]

        token_ids = self.tokenizer.convert_tokens_to_ids(output_tokens)
        result = self.tokenizer.decode(token_ids, skip_special_tokens=True).strip()

        logger.debug("翻译完成: %s -> %s, 耗时: %.2fms", src_lang, tgt_lang, duration * 1000)
        logger.info("[翻译结果] %s", result)

        return result
|
|
|
if __name__ == "__main__":
    # Demo run: translate a fixed set of sentences in both directions and log
    # each result together with the time the call took.
    logger.setLevel(logging.DEBUG)
    translator = NLLBTranslator()

    # Simplified Chinese sentences, translated into English.
    chinese_sentences = [
        "请问这附近有地铁站吗?",
        "我们今天要讨论人工智能的发展趋势。",
        "他的回答令人非常失望。",
        "这个项目已经进行了三个月,还需要更多资源支持。",
        "天气预报说明天会有暴雨,请大家注意安全。",
        "是时候重新思考我们的计划了。",
        "我对这个结果非常满意,感谢你的努力。",
        "她穿着一件红色的连衣裙,在人群中格外显眼。",
    ]

    # English sentences, translated into Simplified Chinese.
    english_sentences = [
        "Can you help me find the nearest bus station?",
        "The machine learning model achieved an accuracy of 95%.",
        "He was overwhelmed by the unexpected response from the audience.",
        "It’s important to stay hydrated during hot summer days.",
        "Although she was tired, she continued working late into the night.",
        "The concert was amazing, and the crowd was full of energy.",
        "Please make sure to submit your application before the deadline.",
        "After months of preparation, the product was finally launched.",
    ]

    # Each case is (text, source language code, target language code).
    test_cases = (
        [(sentence, "zho_Hans", "eng_Latn") for sentence in chinese_sentences]
        + [(sentence, "eng_Latn", "zho_Hans") for sentence in english_sentences]
    )

    for case_no, (text, src_lang, tgt_lang) in enumerate(test_cases, start=1):
        logger.info(f"\n==== 测试用例 {case_no} ====")
        began = time.time()
        result = translator.translate(text, source_lang_code=src_lang, target_lang_code=tgt_lang)
        elapsed_ms = (time.time() - began) * 1000
        logger.info(f"最终翻译结果: {result}")
        logger.info(f"总耗时: {elapsed_ms:.2f}ms")
|
|