DeepLearning101 commited on
Commit
b79c056
·
verified ·
1 Parent(s): 8f16be9

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +71 -0
app.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ import gradio as gr
4
+ import operator
5
+ import torch
6
+ from transformers import BertTokenizer, BertForMaskedLM
7
+
8
+ # 使用私有模型和分詞器
9
+ model_name_or_path = "DeepLearning101/Corrector101zhTW"
10
+ auth_token = "Corrector101zhTW" # 換成您的 Hugging Face API token
11
+
12
+ tokenizer = BertTokenizer.from_pretrained(model_name_or_path, use_auth_token=auth_token)
13
+ model = BertForMaskedLM.from_pretrained(model_name_or_path, use_auth_token=auth_token)
14
+
15
+
16
+ def ai_text(text):
17
+ with torch.no_grad():
18
+ outputs = model(**tokenizer([text], padding=True, return_tensors='pt'))
19
+
20
+ def to_highlight(corrected_sent, errs):
21
+ output = [{"entity": "糾錯", "word": err[1], "start": err[2], "end": err[3]} for i, err in
22
+ enumerate(errs)]
23
+ return {"text": corrected_sent, "entities": output}
24
+
25
+ def get_errors(corrected_text, origin_text):
26
+ sub_details = []
27
+ for i, ori_char in enumerate(origin_text):
28
+ if ori_char in [' ', '“', '”', '‘', '’', '琊', '\n', '…', '—', '擤']:
29
+ # add unk word
30
+ corrected_text = corrected_text[:i] + ori_char + corrected_text[i:]
31
+ continue
32
+ if i >= len(corrected_text):
33
+ continue
34
+ if ori_char != corrected_text[i]:
35
+ if ori_char.lower() == corrected_text[i]:
36
+ # pass english upper char
37
+ corrected_text = corrected_text[:i] + ori_char + corrected_text[i + 1:]
38
+ continue
39
+ sub_details.append((ori_char, corrected_text[i], i, i + 1))
40
+ sub_details = sorted(sub_details, key=operator.itemgetter(2))
41
+ return corrected_text, sub_details
42
+
43
+ _text = tokenizer.decode(torch.argmax(outputs.logits[0], dim=-1), skip_special_tokens=True).replace(' ', '')
44
+ corrected_text = _text[:len(text)]
45
+ corrected_text, details = get_errors(corrected_text, text)
46
+ print(text, ' => ', corrected_text, details)
47
+ return corrected_text + ' ' + str(details)
48
+
49
+
50
+ if __name__ == '__main__':
51
+
52
+ examples = [
53
+ ['你究輸入利的手機門號跟生分證就可以了。'],
54
+ ['這裡是客服中新,很高性為您服物,請問金天有什麼須要幫忙'],
55
+ ['因為我們這邊是按天術比例計蒜給您的,其實不會有態大的穎響。也就是您用前面的資非的廢率來做計算'],
56
+ ['我來看以下,他的時價是多少?起實您就可以直皆就不用到門事'],
57
+ ['因為你現在月富是六九九嘛,我幫擬減衣百塊,兒且也不會江速'],
58
+ ]
59
+
60
+ inputs=[gr.Textbox(lines=2, label="欲校正的文字")],
61
+ outputs=[gr.Textbox(lines=2, label="修正後的文字")],
62
+ gr.Interface(
63
+ inputs='text',
64
+ outputs='text',
65
+ title="客服ASR文本AI糾錯系統",
66
+ description="""
67
+ <a href="https://www.twman.org" target='_blank'>TonTon Huang Ph.D. @ 2024/04 </a><br>
68
+ 輸入ASR文本,糾正同音字/詞錯誤<br>
69
+ Masked Language Model (MLM) as correction BERT
70
+ """, examples=examples
71
+ ).launch()