DeepLearning101's picture
Update app.py
237df5d verified
raw
history blame
2.64 kB
import gradio as gr
import operator
import torch
import os
from transformers import BertTokenizer, BertForMaskedLM
# 使用私有模型和分詞器
model_name_or_path = "DeepLearning101/Corrector101zhTW"
auth_token = os.getenv("HF_HOME")
# 嘗試加載模型和分詞器
try:
tokenizer = BertTokenizer.from_pretrained(model_name_or_path, use_auth_token=auth_token)
model = BertForMaskedLM.from_pretrained(model_name_or_path, use_auth_token=auth_token)
model.eval()
except Exception as e:
print(f"加載模型或分詞器失敗,錯誤信息:{e}")
exit(1)
def ai_text(text):
"""處理輸入文本並返回修正後的文本及錯誤細節"""
with torch.no_grad():
inputs = tokenizer(text, return_tensors="pt", padding=True)
outputs = model(**inputs)
corrected_text, details = get_errors(text, outputs)
return corrected_text + ' ' + str(details)
def get_errors(text, outputs):
"""識別原始文本和模型輸出之間的差異"""
sub_details = []
corrected_text = tokenizer.decode(torch.argmax(outputs.logits[0], dim=-1), skip_special_tokens=True).replace(' ', '')
for i, ori_char in enumerate(text):
if ori_char in [' ', '“', '”', '‘', '’', '琊', '\n', '…', '—', '擤']:
continue
if i >= len(corrected_text):
continue
if ori_char != corrected_text[i]:
sub_details.append((ori_char, corrected_text[i], i, i + 1))
sub_details = sorted(sub_details, key=operator.itemgetter(2))
return corrected_text, sub_details
if __name__ == '__main__':
examples = [
['你究輸入利的手機門號跟生分證就可以了。'],
['這裡是客服中新,很高性為您服物,請問金天有什麼須要幫忙'],
['因為我們這邊是按天術比例計蒜給您的,其實不會有態大的穎響。也就是您用前面的資非的廢率來做計算'],
['我來看以下,他的時價是多少?起實您就可以直皆就不用到門事'],
['因為你現在月富是六九九嘛,我幫擬減衣百塊,兒且也不會江速'],
]
gr.Interface(
fn=ai_text,
inputs=gr.Textbox(lines=2, label="欲校正的文字"),
outputs=gr.Textbox(lines=2, label="修正後的文字"),
title="客服ASR文本AI糾錯系統",
description="""<a href='https://www.twman.org' target='_blank'>TonTon Huang Ph.D. @ 2024/04 </a><br>
輸入ASR文本,糾正同音字/詞錯誤<br>
Masked Language Model (MLM) as correction BERT""",
examples=examples
).launch()