xuandin commited on
Commit
042e3b2
·
verified ·
1 Parent(s): 7725101

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -33
app.py CHANGED
@@ -12,12 +12,14 @@ import psutil
12
  import gc
13
  import threading
14
  from queue import Queue
15
- from concurrent.futures import ThreadPoolExecutor
16
 
17
  # Set environment variables to optimize CPU performance
18
  os.environ["OMP_NUM_THREADS"] = str(psutil.cpu_count(logical=False))
19
  os.environ["MKL_NUM_THREADS"] = str(psutil.cpu_count(logical=False))
20
 
 
 
 
21
  # Load models with caching and CPU optimization
22
  @st.cache_resource()
23
  def load_model(model_name, model_class, is_bc=False, device=None):
@@ -43,9 +45,6 @@ def load_model(model_name, model_class, is_bc=False, device=None):
43
  model.to(device)
44
  return tokenizer, model
45
 
46
- # Set device globally
47
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
48
-
49
  # Pre-process text function to avoid doing it multiple times
50
  @st.cache_data
51
  def preprocess_text(text):
@@ -271,41 +270,19 @@ with st.container():
271
  os.environ["OMP_NUM_THREADS"] = str(num_threads)
272
  os.environ["MKL_NUM_THREADS"] = str(num_threads)
273
 
274
- # Store verification history
275
- if 'history' not in st.session_state:
276
- st.session_state.history = []
277
- if 'latest_result' not in st.session_state:
278
- st.session_state.latest_result = None
279
  if 'models_loaded' not in st.session_state:
280
- st.session_state.models_loaded = False
281
-
282
- # Load the selected models - only reload if model selection changes
283
- if not st.session_state.models_loaded or 'prev_models' not in st.session_state or (
284
- st.session_state.prev_models['qatc'] != qatc_model_name or
285
- st.session_state.prev_models['bc'] != bc_model_name or
286
- st.session_state.prev_models['tc'] != tc_model_name):
287
-
288
  with st.spinner("Loading models..."):
289
- # Clear memory before loading new models
290
- gc.collect()
291
- if DEVICE == "cpu":
292
- torch.set_num_threads(num_threads)
293
-
294
  tokenizer_qatc, model_qatc = load_model(qatc_model_name, QATCForQuestionAnswering, device=DEVICE)
295
  tokenizer_bc, model_bc = load_model(bc_model_name, ClaimModelForClassification, is_bc=True, device=DEVICE)
296
  tokenizer_tc, model_tc = load_model(tc_model_name, ClaimModelForClassification, device=DEVICE)
297
-
298
- st.session_state.prev_models = {
299
- 'qatc': qatc_model_name,
300
- 'bc': bc_model_name,
301
- 'tc': tc_model_name
302
- }
303
  st.session_state.models_loaded = True
304
- else:
305
- # Reuse already loaded models
306
- tokenizer_qatc, model_qatc = load_model(qatc_model_name, QATCForQuestionAnswering, device=DEVICE)
307
- tokenizer_bc, model_bc = load_model(bc_model_name, ClaimModelForClassification, is_bc=True, device=DEVICE)
308
- tokenizer_tc, model_tc = load_model(tc_model_name, ClaimModelForClassification, device=DEVICE)
 
309
 
310
  # Icons for results
311
  verdict_icons = {
 
12
  import gc
13
  import threading
14
  from queue import Queue
 
15
 
16
  # Set environment variables to optimize CPU performance
17
  os.environ["OMP_NUM_THREADS"] = str(psutil.cpu_count(logical=False))
18
  os.environ["MKL_NUM_THREADS"] = str(psutil.cpu_count(logical=False))
19
 
20
+ # Set device globally
21
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
22
+
23
  # Load models with caching and CPU optimization
24
  @st.cache_resource()
25
  def load_model(model_name, model_class, is_bc=False, device=None):
 
45
  model.to(device)
46
  return tokenizer, model
47
 
 
 
 
48
  # Pre-process text function to avoid doing it multiple times
49
  @st.cache_data
50
  def preprocess_text(text):
 
270
  os.environ["OMP_NUM_THREADS"] = str(num_threads)
271
  os.environ["MKL_NUM_THREADS"] = str(num_threads)
272
 
273
+ # Load models once and keep them in memory
 
 
 
 
274
  if 'models_loaded' not in st.session_state:
 
 
 
 
 
 
 
 
275
  with st.spinner("Loading models..."):
 
 
 
 
 
276
  tokenizer_qatc, model_qatc = load_model(qatc_model_name, QATCForQuestionAnswering, device=DEVICE)
277
  tokenizer_bc, model_bc = load_model(bc_model_name, ClaimModelForClassification, is_bc=True, device=DEVICE)
278
  tokenizer_tc, model_tc = load_model(tc_model_name, ClaimModelForClassification, device=DEVICE)
 
 
 
 
 
 
279
  st.session_state.models_loaded = True
280
+
281
+ # Store verification history
282
+ if 'history' not in st.session_state:
283
+ st.session_state.history = []
284
+ if 'latest_result' not in st.session_state:
285
+ st.session_state.latest_result = None
286
 
287
  # Icons for results
288
  verdict_icons = {