kadabengaran commited on
Commit
b2d3878
·
1 Parent(s): 99c412b

fix csv input

Browse files
Files changed (1) hide show
  1. app/main.py +6 -5
app/main.py CHANGED
@@ -140,7 +140,8 @@ class App:
140
  self.fileTypes = ["csv"]
141
  self.default_tab_selected = tab_labels[0]
142
  self.input_text = None
143
- self.input_file = None
 
144
 
145
  def run(self):
146
  self.init_session_state() # Initialize session state
@@ -214,8 +215,8 @@ class App:
214
  return
215
 
216
  df_process = data[ques]
217
- self.input_file = data
218
- self.process_file = df_process
219
 
220
  def render_process_button(self, model, tokenizer, device):
221
  if st.button("Process"):
@@ -226,13 +227,13 @@ class App:
226
  prediction_label = get_key(prediction, LABELS)
227
  st.write("Prediction:", prediction_label)
228
  elif st.session_state.tab_selected == tab_labels[1]:
229
- df_process = self.process_file
230
  if df_process is not None:
231
  prediction = predict_multiple(df_process, model, tokenizer, device)
232
 
233
  st.divider()
234
  st.write("Classification Result")
235
- input_file = self.input_file
236
  input_file["classification_result"] = prediction
237
  st.dataframe(input_file.head(10))
238
  st.download_button(
 
140
  self.fileTypes = ["csv"]
141
  self.default_tab_selected = tab_labels[0]
142
  self.input_text = None
143
+ self.csv_input = None
144
+ self.csv_process = None
145
 
146
  def run(self):
147
  self.init_session_state() # Initialize session state
 
215
  return
216
 
217
  df_process = data[ques]
218
+ self.csv_input = data
219
+ self.csv_process = df_process
220
 
221
  def render_process_button(self, model, tokenizer, device):
222
  if st.button("Process"):
 
227
  prediction_label = get_key(prediction, LABELS)
228
  st.write("Prediction:", prediction_label)
229
  elif st.session_state.tab_selected == tab_labels[1]:
230
+ df_process = self.csv_process
231
  if df_process is not None:
232
  prediction = predict_multiple(df_process, model, tokenizer, device)
233
 
234
  st.divider()
235
  st.write("Classification Result")
236
+ input_file = self.csv_input
237
  input_file["classification_result"] = prediction
238
  st.dataframe(input_file.head(10))
239
  st.download_button(