"""Gradio demo for Korean hate-speech and personal-information detection.

Run with:

    python interactive.py
"""
|
from transformers import AutoTokenizer, AutoModelForSequenceClassification |
|
from transformers import TextClassificationPipeline |
|
import gradio as gr |
|
|
|
|
|
# Hugging Face checkpoint loaded at start-up; predict() may swap it later.
MODEL_NAME = 'momo/KcBERT-base_Hate_speech_Privacy_Detection'

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# 15 outputs scored as a multi-label problem (one per detection label).
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME,
    problem_type="multi_label_classification",
    num_labels=15,
)

# The currently active checkpoint, read and replaced by change_model_name()
# and predict().
MODEL_BUF = dict(
    name=MODEL_NAME,
    tokenizer=tokenizer,
    model=model,
)
|
|
|
def change_model_name(name):
    """Load checkpoint *name* and make it the active entry in ``MODEL_BUF``.

    The model is loaded with the same ``num_labels`` / ``problem_type``
    settings as the start-up checkpoint, so every selectable model is
    configured the way ``predict`` expects (15 labels, multi-label).

    Args:
        name: Model identifier passed to ``from_pretrained``.

    Raises:
        Whatever ``from_pretrained`` raises when the checkpoint cannot be
        loaded. ``MODEL_BUF`` is left untouched in that case.
    """
    # Load everything before publishing. If a load fails, MODEL_BUF must not
    # claim the new name while still holding the old tokenizer/model,
    # because predict() would then never retry the switch.
    new_tokenizer = AutoTokenizer.from_pretrained(name)
    new_model = AutoModelForSequenceClassification.from_pretrained(
        name,
        num_labels=15,
        problem_type="multi_label_classification",
    )
    MODEL_BUF.update(name=name, tokenizer=new_tokenizer, model=new_model)
|
|
|
# Label names in classifier output-index order. The first nine are the
# UnSmile hate-speech categories: women/family, men, sexual minorities,
# race/nationality, age, region, religion, other hate, malicious
# comments/profanity. Then comes "clean", followed by five personal-
# information categories: name, phone number, address, bank account number,
# resident registration number.
UNSMILE_LABELS = (
    "여성/가족", "남성", "성소수자", "인종/국적", "연령",
    "지역", "종교", "기타 혐오", "악플/욕설", "clean",
    "이름", "전화번호", "주소", "계좌번호", "주민번호",
)


def predict(model_name, text):
    """Score *text* against every label using the selected checkpoint.

    If *model_name* is given and differs from the checkpoint currently in
    ``MODEL_BUF``, that checkpoint is loaded first. An empty or ``None``
    selection keeps the currently loaded model.

    Args:
        model_name: Model identifier chosen in the UI dropdown.
        text: Sentence to classify.

    Returns:
        A newline-separated string with one ``"<label>\\t<score>"`` line per
        label. Each score is the sigmoid output for that label.
    """
    if model_name and model_name != MODEL_BUF["name"]:
        change_model_name(model_name)

    tokenizer = MODEL_BUF["tokenizer"]
    model = MODEL_BUF["model"]

    # Attach the label names to the config so pipeline results carry them.
    model.config.id2label = dict(enumerate(UNSMILE_LABELS))
    model.config.label2id = {label: i for i, label in enumerate(UNSMILE_LABELS)}

    pipe = TextClassificationPipeline(
        model=model,
        tokenizer=tokenizer,
        return_all_scores=True,
        # Sigmoid scores each label independently (multi-label output).
        function_to_apply='sigmoid',
    )

    # Run inference once and reuse the result for logging and the response.
    scores = pipe(text)[0]
    print(scores)

    # Each entry is a {"label": ..., "score": ...} mapping, not a string.
    return "\n".join(f"{item['label']}\t{item['score']:.4f}" for item in scores)
|
|
|
if __name__ == '__main__':
    # Example inputs: an address, a phone number, and an annoyed remark.
    exam1 = '경기도 성남시 수정구 태평3동에 우리 동네야!'
    exam2 = '내 핸드폰 번호는 010-3930-8237 이야!'
    exam3 = '저 점장 너무 짜증난다'

    # Checkpoints selectable in the UI; predict() switches to the chosen one.
    model_name_list = [
        'momo/KcELECTRA-base_Hate_speech_Privacy_Detection',
        "momo/KcBERT-base_Hate_speech_Privacy_Detection",
    ]

    app = gr.Interface(
        fn=predict,
        inputs=[
            # Pre-select the loaded checkpoint so submitting without touching
            # the dropdown does not send None as the model name.
            # NOTE(review): gr.inputs is Gradio's legacy namespace (uses
            # `default=`); newer releases use gr.Dropdown(..., value=...).
            # Confirm the pinned Gradio version before migrating.
            gr.inputs.Dropdown(
                model_name_list,
                default=MODEL_BUF["name"],
                label="Model Name",
            ),
            'text',
        ],
        outputs='text',
        examples=[[MODEL_BUF["name"], sample] for sample in (exam1, exam2, exam3)],
        title="한국어 혐오표현, 개인정보 판별기 (Korean Hate Speech and Privacy Detection)",
        description=(
            "Korean Hate Speech and Privacy Detection. \t 15개 label Detection: "
            "여성/가족, 남성, 성소수자, 인종/국적, 연령, 지역, 종교, 기타 혐오, "
            "악플/욕설, clean, name, number, address, bank, person"
        ),
    )

    app.launch()