Commit b767f00
Parent(s): e186fb5
Squashed commit of the following:
commit 8e43136c9db5455c3248a2d95aa18a4e7a25bd39
Author: Ardhy Satrio <[email protected]>
Date: Mon May 29 11:15:13 2023 +0800
multi input update
- app/main.py +168 -56
app/main.py
CHANGED
@@ -1,15 +1,33 @@
-import
+try:
+    import torch
+
+    import pandas as pd
+    import streamlit as st
+    import re
+    import streamlit as st
+    from transformers import BertTokenizer, BertModel
+    from model import IndoBERTBiLSTM, IndoBERTModel
+except Exception as e:
+    print(e)
+
+STYLE = """
+<style>
+img {
+    max-width: 100%;
+}
+</style>
+"""
 # Config
 MAX_SEQ_LEN = 128
-bert_path = '
+bert_path = './local/base-indobert'
+# bert_path = 'indolem/indobert-base-uncased'
+# MODELS_PATH = ["kadabengaran/IndoBERT-Useful-App-Review",
+#                "kadabengaran/IndoBERT-BiLSTM-Useful-App-Review"]
+MODELS_PATH = ["./local/indobert1",
+               "./local/indobert2"]
+
+MODELS_NAME = ["IndoBERT-BiLSTM", "IndoBERT"]
+LABELS = {'Not Useful': 0, 'Useful': 1}
 # "kadabengaran/IndoBERT-BiLSTM-Useful-App-Review"]
 HIDDEN_DIM = 768
 OUTPUT_DIM = 2 # 2 if Binary
@@ -37,15 +55,9 @@ def load_tokenizer(model_path):
 
 
 def remove_special_characters(text):
-    # remove special characters except punctuation such as periods, commas, and question marks
-    # text = re.sub(r"[^a-zA-Z0-9.,!?]+", " ", text)
     text = re.sub(r'[^a-zA-Z0-9\s]', '', text)
-
-    # text = re.sub(r"'\s+|\s+'", " ", text) # replace apostrophe with space if it's surrounded by whitespace
     text = re.sub(r"\s+", " ", text) # replace multiple whitespace characters with a single space
-
     text = re.sub(r'[0-9]', ' ', text) #remove number
-
     text = text.lower()
     return text
 
@@ -61,21 +73,19 @@ def load_model():
     bert = BertModel.from_pretrained(bert_path)
 
     # Load the model
+    model_combined = IndoBERTBiLSTM.from_pretrained(MODELS_PATH[0],
                                                     bert,
                                                     HIDDEN_DIM,
                                                     OUTPUT_DIM,
                                                     N_LAYERS, BIDIRECTIONAL,
                                                     DROPOUT)
+    model_base = IndoBERTModel.from_pretrained(MODELS_PATH[1],
                                                bert,
                                                OUTPUT_DIM)
-    return
+    return model_combined, model_base
 
-def predict(text, model, tokenizer, device):
+def predict_single(text, model, tokenizer, device):
 
-    # model = torch.load(model_path, map_location=device)
     if device.type == 'cuda':
         model.cuda()
 
@@ -102,38 +112,140 @@ def predict(text, model, tokenizer, device):
     print("output ", predictions)
     return predictions.item()
 
-def
-    # st.info("Prediction with ML")
-    input_text = st.text_area("Enter Text Here", placeholder="Type Here")
-    all_ml_models = ["IndoBERT", "IndoBERT-BiLSTM"]
-    model_choice = st.selectbox("Select Model", all_ml_models)
-    tokenizer = load_tokenizer(bert_path)
-    device = get_device()
-    model1, model2 = load_model()
+def predict_multiple(data, model, tokenizer, device):
+    input_ids = []
+    attention_masks = []
+    for row in data.tolist():
+        # Apply remove_special_characters function to title column
+        text = remove_special_characters(row)
+        text = preprocess(text, tokenizer)
+        input_ids.append(text['input_ids'])
+        attention_masks.append(text['attention_mask'])
+
+    predictions = []
 
-    prediction =
+    with torch.no_grad():
+        for i in range(len(input_ids)):
+            test_ids = input_ids[i]
+            test_attention_mask = attention_masks[i]
+            outputs = model(test_ids.to(device), test_attention_mask.to(device))
+            prediction = torch.argmax(outputs, dim=-1)
+            prediction_label = get_key(prediction.item(), LABELS)
+            predictions.append(prediction_label)
+
+    return predictions
+
+tab_labels = ["Single Input", "Multiple Input"]
+class App:
+
+    print("Loading All")
+    def __init__(self):
+        self.fileTypes = ["csv"]
+        self.default_tab_selected = tab_labels[0]
+        self.input_text = None
+        self.input_file = None
+
+    def run(self):
+        self.init_session_state() # Initialize session state
+        tokenizer = load_tokenizer(bert_path)
+        device = get_device()
+        model_combined, model_base = load_model()
+        """App Review Classifier"""
+        html_temp = """
+        <div style="background-color:blue;padding:10px">
+        <h1 style="color:white;text-align:center;">Klasifikasi Ulasan Aplikasi yang Berguna</h1>
+        </div>
+        """
+        st.markdown(html_temp, unsafe_allow_html=True)
+        self.render_tabs()
+        st.divider()
+        model_choice = self.render_model_selection()
+        if model_choice:
+            if model_choice == MODELS_NAME[0]:
+                model = model_combined
+            elif model_choice == MODELS_NAME[1]:
+                model = model_base
+            self.render_process_button(model, tokenizer, device)
+
+    def init_session_state(self):
+        if "tab_selected" not in st.session_state:
+            st.session_state.tab_selected = tab_labels[0]
+
+    def render_model_selection(self):
+        model_choice = st.selectbox("Select Model", MODELS_NAME)
+        return model_choice
+
+    def render_tabs(self):
+        tab_selected = st.session_state.get('tab_selected', self.default_tab_selected)
+        tab_selected = st.sidebar.radio("Select Input Type", tab_labels)
+        # tab1, tab2 = st.tabs(tab_labels)
+
+        if tab_selected == tab_labels[0]:
+            self.render_single_input()
+        elif tab_selected == tab_labels[1]:
+            self.render_multiple_input()
+
+        st.session_state.tab_selected = tab_selected
+
+    def render_single_input(self):
+        self.input_text = st.text_area("Enter Text Here", placeholder="Type Here")
+
+    def render_multiple_input(self):
+        """
+        Upload File
+        """
+        st.markdown(STYLE, unsafe_allow_html=True)
+        file = st.file_uploader("Upload file", type=self.fileTypes)
+
+
+        if not file:
+            st.info("Please upload a file of type: " + ", ".join(self.fileTypes))
+            return
+
+        data = pd.read_csv(file)
+
+        placeholder = st.empty()
+        placeholder.dataframe(data.head(10))
+
+
+        header_list = data.columns.tolist()
+        header_list.insert(0, "---------- select column -------------")
+        ques = st.radio("Select column to process", header_list, index=0)
+
+        if header_list.index(ques) == 0:
+            st.warning("Please select a column to process")
+            return
+
+        df_process = data[ques]
+        self.input_file = data
+        self.process_file = df_process
+
+    def render_process_button(self, model, tokenizer, device):
+        if st.button("Process"):
+            if st.session_state.tab_selected == tab_labels[0]:
+                input_text = self.input_text
+                if input_text:
+                    prediction = predict_single(input_text, model, tokenizer, device)
+                    prediction_label = get_key(prediction, LABELS)
+                    st.write("Prediction:", prediction_label)
+            elif st.session_state.tab_selected == tab_labels[1]:
+                df_process = self.process_file
+                if df_process is not None:
+                    prediction = predict_multiple(df_process, model, tokenizer, device)
+
+                    st.divider()
+                    st.write("Classification Result")
+                    input_file = self.input_file
+                    input_file["classification_result"] = prediction
+                    st.dataframe(input_file.head(10))
+                    st.download_button(
+                        label="Download Result",
+                        data=input_file.to_csv().encode("utf-8"),
+                        file_name="classification_result.csv",
+                        mime="text/csv",
+                    )
+
+
+if __name__ == "__main__":
+    app = App()
+    app.run()
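
The multi-input path added above reduces to a simple batch flow: read a column of texts, clean each row, tokenize it, run the classifier row by row, and map each argmax index back to a label name. Below is a minimal, self-contained sketch of that flow for illustration only; DummyClassifier, the hashing "tokenizer", clean(), predict_rows(), and the sample rows are hypothetical stand-ins, not the IndoBERT models or the helpers defined in app/main.py.

# Hypothetical sketch of the commit's batch-prediction pattern with stand-in pieces.
import re
import torch
import pandas as pd

LABELS = {'Not Useful': 0, 'Useful': 1}

def clean(text):
    # same normalization idea as remove_special_characters() above
    text = re.sub(r'[^a-zA-Z0-9\s]', '', text)
    text = re.sub(r'\s+', ' ', text)
    return text.lower()

class DummyClassifier(torch.nn.Module):
    # stand-in for IndoBERTBiLSTM / IndoBERTModel: token ids -> 2 logits
    def __init__(self, vocab_size=100, output_dim=2):
        super().__init__()
        self.embedding = torch.nn.EmbeddingBag(vocab_size, 16)
        self.fc = torch.nn.Linear(16, output_dim)

    def forward(self, ids):
        return self.fc(self.embedding(ids))

def predict_rows(series, model):
    # mirror predict_multiple(): clean, tokenize, classify each row, map argmax to a label
    inv_labels = {v: k for k, v in LABELS.items()}
    predictions = []
    with torch.no_grad():
        for row in series.tolist():
            text = clean(row)
            # toy tokenizer: hash each word into a small id space
            ids = torch.tensor([[hash(w) % 100 for w in text.split()] or [0]])
            logits = model(ids)
            predictions.append(inv_labels[int(torch.argmax(logits, dim=-1))])
    return predictions

if __name__ == "__main__":
    df = pd.DataFrame({"review": ["great app very helpful", "crashes on startup"]})
    model = DummyClassifier().eval()
    df["classification_result"] = predict_rows(df["review"], model)
    print(df)

With the real app, the model and tokenizer would instead come from load_model() and load_tokenizer(bert_path), exactly as predict_multiple() and App.run() do in the diff above.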