Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -12,12 +12,14 @@ import psutil
|
|
12 |
import gc
|
13 |
import threading
|
14 |
from queue import Queue
|
15 |
-
from concurrent.futures import ThreadPoolExecutor
|
16 |
|
17 |
# Set environment variables to optimize CPU performance
|
18 |
os.environ["OMP_NUM_THREADS"] = str(psutil.cpu_count(logical=False))
|
19 |
os.environ["MKL_NUM_THREADS"] = str(psutil.cpu_count(logical=False))
|
20 |
|
|
|
|
|
|
|
21 |
# Load models with caching and CPU optimization
|
22 |
@st.cache_resource()
|
23 |
def load_model(model_name, model_class, is_bc=False, device=None):
|
@@ -43,9 +45,6 @@ def load_model(model_name, model_class, is_bc=False, device=None):
|
|
43 |
model.to(device)
|
44 |
return tokenizer, model
|
45 |
|
46 |
-
# Set device globally
|
47 |
-
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
48 |
-
|
49 |
# Pre-process text function to avoid doing it multiple times
|
50 |
@st.cache_data
|
51 |
def preprocess_text(text):
|
@@ -271,41 +270,19 @@ with st.container():
|
|
271 |
os.environ["OMP_NUM_THREADS"] = str(num_threads)
|
272 |
os.environ["MKL_NUM_THREADS"] = str(num_threads)
|
273 |
|
274 |
-
#
|
275 |
-
if 'history' not in st.session_state:
|
276 |
-
st.session_state.history = []
|
277 |
-
if 'latest_result' not in st.session_state:
|
278 |
-
st.session_state.latest_result = None
|
279 |
if 'models_loaded' not in st.session_state:
|
280 |
-
st.session_state.models_loaded = False
|
281 |
-
|
282 |
-
# Load the selected models - only reload if model selection changes
|
283 |
-
if not st.session_state.models_loaded or 'prev_models' not in st.session_state or (
|
284 |
-
st.session_state.prev_models['qatc'] != qatc_model_name or
|
285 |
-
st.session_state.prev_models['bc'] != bc_model_name or
|
286 |
-
st.session_state.prev_models['tc'] != tc_model_name):
|
287 |
-
|
288 |
with st.spinner("Loading models..."):
|
289 |
-
# Clear memory before loading new models
|
290 |
-
gc.collect()
|
291 |
-
if DEVICE == "cpu":
|
292 |
-
torch.set_num_threads(num_threads)
|
293 |
-
|
294 |
tokenizer_qatc, model_qatc = load_model(qatc_model_name, QATCForQuestionAnswering, device=DEVICE)
|
295 |
tokenizer_bc, model_bc = load_model(bc_model_name, ClaimModelForClassification, is_bc=True, device=DEVICE)
|
296 |
tokenizer_tc, model_tc = load_model(tc_model_name, ClaimModelForClassification, device=DEVICE)
|
297 |
-
|
298 |
-
st.session_state.prev_models = {
|
299 |
-
'qatc': qatc_model_name,
|
300 |
-
'bc': bc_model_name,
|
301 |
-
'tc': tc_model_name
|
302 |
-
}
|
303 |
st.session_state.models_loaded = True
|
304 |
-
|
305 |
-
|
306 |
-
|
307 |
-
|
308 |
-
|
|
|
309 |
|
310 |
# Icons for results
|
311 |
verdict_icons = {
|
|
|
12 |
import gc
|
13 |
import threading
|
14 |
from queue import Queue
|
|
|
15 |
|
16 |
# Set environment variables to optimize CPU performance
|
17 |
os.environ["OMP_NUM_THREADS"] = str(psutil.cpu_count(logical=False))
|
18 |
os.environ["MKL_NUM_THREADS"] = str(psutil.cpu_count(logical=False))
|
19 |
|
20 |
+
# Set device globally
|
21 |
+
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
22 |
+
|
23 |
# Load models with caching and CPU optimization
|
24 |
@st.cache_resource()
|
25 |
def load_model(model_name, model_class, is_bc=False, device=None):
|
|
|
45 |
model.to(device)
|
46 |
return tokenizer, model
|
47 |
|
|
|
|
|
|
|
48 |
# Pre-process text function to avoid doing it multiple times
|
49 |
@st.cache_data
|
50 |
def preprocess_text(text):
|
|
|
270 |
os.environ["OMP_NUM_THREADS"] = str(num_threads)
|
271 |
os.environ["MKL_NUM_THREADS"] = str(num_threads)
|
272 |
|
273 |
+
# Load models once and keep them in memory
|
|
|
|
|
|
|
|
|
274 |
if 'models_loaded' not in st.session_state:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
275 |
with st.spinner("Loading models..."):
|
|
|
|
|
|
|
|
|
|
|
276 |
tokenizer_qatc, model_qatc = load_model(qatc_model_name, QATCForQuestionAnswering, device=DEVICE)
|
277 |
tokenizer_bc, model_bc = load_model(bc_model_name, ClaimModelForClassification, is_bc=True, device=DEVICE)
|
278 |
tokenizer_tc, model_tc = load_model(tc_model_name, ClaimModelForClassification, device=DEVICE)
|
|
|
|
|
|
|
|
|
|
|
|
|
279 |
st.session_state.models_loaded = True
|
280 |
+
|
281 |
+
# Store verification history
|
282 |
+
if 'history' not in st.session_state:
|
283 |
+
st.session_state.history = []
|
284 |
+
if 'latest_result' not in st.session_state:
|
285 |
+
st.session_state.latest_result = None
|
286 |
|
287 |
# Icons for results
|
288 |
verdict_icons = {
|