Spaces:
Runtime error
Runtime error
Commit
·
b2d3878
1
Parent(s):
99c412b
fix csv input
Browse files- app/main.py +6 -5
app/main.py
CHANGED
@@ -140,7 +140,8 @@ class App:
|
|
140 |
self.fileTypes = ["csv"]
|
141 |
self.default_tab_selected = tab_labels[0]
|
142 |
self.input_text = None
|
143 |
-
self.
|
|
|
144 |
|
145 |
def run(self):
|
146 |
self.init_session_state() # Initialize session state
|
@@ -214,8 +215,8 @@ class App:
|
|
214 |
return
|
215 |
|
216 |
df_process = data[ques]
|
217 |
-
self.
|
218 |
-
self.
|
219 |
|
220 |
def render_process_button(self, model, tokenizer, device):
|
221 |
if st.button("Process"):
|
@@ -226,13 +227,13 @@ class App:
|
|
226 |
prediction_label = get_key(prediction, LABELS)
|
227 |
st.write("Prediction:", prediction_label)
|
228 |
elif st.session_state.tab_selected == tab_labels[1]:
|
229 |
-
df_process = self.
|
230 |
if df_process is not None:
|
231 |
prediction = predict_multiple(df_process, model, tokenizer, device)
|
232 |
|
233 |
st.divider()
|
234 |
st.write("Classification Result")
|
235 |
-
input_file = self.
|
236 |
input_file["classification_result"] = prediction
|
237 |
st.dataframe(input_file.head(10))
|
238 |
st.download_button(
|
|
|
140 |
self.fileTypes = ["csv"]
|
141 |
self.default_tab_selected = tab_labels[0]
|
142 |
self.input_text = None
|
143 |
+
self.csv_input = None
|
144 |
+
self.csv_process = None
|
145 |
|
146 |
def run(self):
|
147 |
self.init_session_state() # Initialize session state
|
|
|
215 |
return
|
216 |
|
217 |
df_process = data[ques]
|
218 |
+
self.csv_input = data
|
219 |
+
self.csv_process = df_process
|
220 |
|
221 |
def render_process_button(self, model, tokenizer, device):
|
222 |
if st.button("Process"):
|
|
|
227 |
prediction_label = get_key(prediction, LABELS)
|
228 |
st.write("Prediction:", prediction_label)
|
229 |
elif st.session_state.tab_selected == tab_labels[1]:
|
230 |
+
df_process = self.csv_process
|
231 |
if df_process is not None:
|
232 |
prediction = predict_multiple(df_process, model, tokenizer, device)
|
233 |
|
234 |
st.divider()
|
235 |
st.write("Classification Result")
|
236 |
+
input_file = self.csv_input
|
237 |
input_file["classification_result"] = prediction
|
238 |
st.dataframe(input_file.head(10))
|
239 |
st.download_button(
|