Update app.py
Browse files

Caching examples at: '/home/user/app/gradio_cached_examples/15'
Caching example 1/5
Traceback (most recent call last):
File "/home/user/app/app.py", line 50, in <module>
gr.Interface(
File "/usr/local/lib/python3.10/site-packages/gradio/interface.py", line 515, in __init__
self.render_examples()
File "/usr/local/lib/python3.10/site-packages/gradio/interface.py", line 861, in render_examples
self.examples_handler = Examples(
File "/usr/local/lib/python3.10/site-packages/gradio/helpers.py", line 74, in create_examples
examples_obj.create()
File "/usr/local/lib/python3.10/site-packages/gradio/helpers.py", line 314, in create
self._start_caching()
File "/usr/local/lib/python3.10/site-packages/gradio/helpers.py", line 365, in _start_caching
client_utils.synchronize_async(self.cache)
File "/usr/local/lib/python3.10/site-packages/gradio_client/utils.py", line 858, in synchronize_async
return fsspec.asyn.sync(fsspec.asyn.get_loop(), func, *args, **kwargs) # type: ignore
File "/usr/local/lib/python3.10/site-packages/fsspec/asyn.py", line 103, in sync
raise return_result
File "/usr/local/lib/python3.10/site-packages/fsspec/asyn.py", line 56, in _runner
result[0] = await coro
File "/usr/local/lib/python3.10/site-packages/gradio/helpers.py", line 486, in cache
prediction = await Context.root_block.process_api(
File "/usr/local/lib/python3.10/site-packages/gradio/blocks.py", line 1847, in process_api
result = await self.call_function(
File "/usr/local/lib/python3.10/site-packages/gradio/blocks.py", line 1433, in call_function
prediction = await anyio.to_thread.run_sync(
File "/usr/local/lib/python3.10/site-packages/anyio/to_thread.py", line 56, in run_sync
return await get_async_backend().run_sync_in_worker_thread(
File "/usr/local/lib/python3.10/site-packages/anyio/_backends/_asyncio.py", line 2144, in run_sync_in_worker_thread
return await future
File "/usr/local/lib/python3.10/site-packages/anyio/_backends/_asyncio.py", line 851, in run
result = context.run(func, *args)
File "/usr/local/lib/python3.10/site-packages/gradio/utils.py", line 788, in wrapper
response = f(*args, **kwargs)
File "/home/user/app/app.py", line 21, in ai_text
corrected_text, details = get_errors(text)
File "/home/user/app/app.py", line 30, in get_errors
corrected_text = tokenizer.decode(torch.argmax(outputs.logits[0], dim=-1), skip_special_tokens=True).replace(' ', '')
NameError: name 'outputs' is not defined
錯誤是因為在get_errors函數中引用了未定義的變數outputs。這個變數在ai_text函數中定義,但並沒有傳遞到get_errors函數中。為了解決這個問題,您需要將outputs作為參數傳遞到get_errors函數中,或在get_errors函數內部重新計算它。
(The diff below has been reconstructed into unified format from the garbled side-by-side rendering; removed lines whose original content was lost in that rendering are shown as "…".)

@@ -11,25 +11,24 @@ model_name_or_path = "DeepLearning101/Corrector101zhTW"
 try:
     tokenizer = BertTokenizer.from_pretrained(model_name_or_path)
     model = BertForMaskedLM.from_pretrained(model_name_or_path)
+    model.eval()  # 將模型設置為評估模式
 except Exception as e:
     print(f"加載模型或分詞器失敗,錯誤信息:{e}")
     exit(1)
 
 def ai_text(text):
+    """處理輸入文本並返回修正後的文本及錯誤細節"""
     with torch.no_grad():
-        …
-        …
+        inputs = tokenizer(text, return_tensors="pt", padding=True)
+        outputs = model(**inputs)
+        corrected_text, details = get_errors(text, outputs)
     return corrected_text + ' ' + str(details)
 
-def …
-    …
-        return {"text": corrected_sent, "entities": output}
-
-def get_errors(text):
+def get_errors(text, outputs):
+    """識別原始文本和模型輸出之間的差異"""
     sub_details = []
     corrected_text = tokenizer.decode(torch.argmax(outputs.logits[0], dim=-1), skip_special_tokens=True).replace(' ', '')
     for i, ori_char in enumerate(text):
-        # 略過特定字符
         if ori_char in [' ', '“', '”', '‘', '’', '琊', '\n', '…', '—', '擤']:
             continue
         if i >= len(corrected_text):
@@ -52,8 +51,8 @@ if __name__ == '__main__':
         inputs=gr.Textbox(lines=2, label="欲校正的文字"),
         outputs=gr.Textbox(lines=2, label="修正後的文字"),
         title="客服ASR文本AI糾錯系統",
-        description="""<a href=…
+        description="""<a href='https://www.twman.org' target='_blank'>TonTon Huang Ph.D. @ 2024/04 </a><br>
         輸入ASR文本,糾正同音字/詞錯誤<br>
         Masked Language Model (MLM) as correction BERT""",
         examples=examples
-    ).launch()
+    ).launch()