Update app.py
app.py
CHANGED
@@ -7,8 +7,18 @@ from semviqa.ser.ser_eval import extract_evidence_tfidf_qatc
 from semviqa.tvc.tvc_eval import classify_claim
 import time
 import pandas as pd
+import os
+import psutil
+import gc
+import threading
+from queue import Queue
+from concurrent.futures import ThreadPoolExecutor
 
-#
+# Set environment variables to optimize CPU performance
+os.environ["OMP_NUM_THREADS"] = str(psutil.cpu_count(logical=False))
+os.environ["MKL_NUM_THREADS"] = str(psutil.cpu_count(logical=False))
+
+# Load models with caching and CPU optimization
 @st.cache_resource()
 def load_model(model_name, model_class, is_bc=False, device=None):
     if device is None:
@@ -17,13 +27,20 @@ def load_model(model_name, model_class, is_bc=False, device=None):
     tokenizer = AutoTokenizer.from_pretrained(model_name)
     model = model_class.from_pretrained(model_name, num_labels=3 if not is_bc else 2)
     model.eval()
-    model.to(device)
 
-    #
-    if device == "
-
-
+    # CPU-specific optimizations
+    if device == "cpu":
+        # Use torch's quantization for CPU inference speed boost
+        try:
+            import torch.quantization
+            # Quantize the model to INT8
+            model = torch.quantization.quantize_dynamic(
+                model, {torch.nn.Linear}, dtype=torch.qint8
+            )
+        except Exception as e:
+            st.warning(f"Quantization failed, using default model: {e}")
 
+    model.to(device)
     return tokenizer, model
 
 # Set device globally
@@ -35,28 +52,52 @@ def preprocess_text(text):
     # Add any text cleaning or normalization here
     return text.strip()
 
-#
-def
-
+# Function to extract evidence in a separate thread for better CPU utilization
+def extract_evidence_threaded(queue, claim, context, model_qatc, tokenizer_qatc, device,
+                              tfidf_threshold, length_ratio_threshold):
+    start_time = time.time()
     with torch.no_grad():
-        # Extract evidence
-        evidence_start_time = time.time()
         evidence = extract_evidence_tfidf_qatc(
             claim, context, model_qatc, tokenizer_qatc,
-
+            device,
             confidence_threshold=tfidf_threshold,
             length_ratio_threshold=length_ratio_threshold
         )
-
+    evidence_time = time.time() - start_time
+    queue.put((evidence, evidence_time))
+
+# Function to classify in a separate thread
+def classify_claim_threaded(queue, claim, evidence, model, tokenizer, device):
+    with torch.no_grad():
+        result = classify_claim(claim, evidence, model, tokenizer, device)
+    queue.put(result)
 
-
-
+# Optimized function for evidence extraction and classification with better CPU performance
+def perform_verification(claim, context, model_qatc, tokenizer_qatc, model_tc, tokenizer_tc,
+                         model_bc, tokenizer_bc, tfidf_threshold, length_ratio_threshold):
+    # Use thread for evidence extraction to allow garbage collection in between
+    evidence_queue = Queue()
+    evidence_thread = threading.Thread(
+        target=extract_evidence_threaded,
+        args=(evidence_queue, claim, context, model_qatc, tokenizer_qatc, DEVICE,
+              tfidf_threshold, length_ratio_threshold)
+    )
+    evidence_thread.start()
+    evidence_thread.join()
+    evidence, evidence_time = evidence_queue.get()
+
+    # Explicit garbage collection after evidence extraction
+    gc.collect()
+
+    # Classify the claim
+    verdict_start_time = time.time()
+    with torch.no_grad():
         prob3class, pred_tc = classify_claim(
             claim, evidence, model_tc, tokenizer_tc, DEVICE
         )
 
         # Only run binary classifier if needed
-        prob2class, pred_bc = 0,
+        prob2class, pred_bc = 0, 0
         if pred_tc != 0:
             prob2class, pred_bc = classify_claim(
                 claim, evidence, model_bc, tokenizer_bc, DEVICE
@@ -65,7 +106,7 @@ def perform_verification(claim, context, model_qatc, tokenizer_qatc, model_tc, t
     else:
        verdict = "NEI"
 
-
+    verdict_time = time.time() - verdict_start_time
 
    return {
        "evidence": evidence,
@@ -222,6 +263,13 @@ with st.container():
            "SemViQA/tc-erniem-isedsc01"
        ])
        show_details = st.checkbox("Show Probability Details", value=False)
+
+       # Add CPU optimization settings
+       st.subheader("CPU Performance Settings")
+       num_threads = st.slider("Number of CPU Threads", 1, psutil.cpu_count(),
+                               psutil.cpu_count(logical=False))
+       os.environ["OMP_NUM_THREADS"] = str(num_threads)
+       os.environ["MKL_NUM_THREADS"] = str(num_threads)
 
    # Store verification history
    if 'history' not in st.session_state:
@@ -238,6 +286,11 @@ with st.container():
        st.session_state.prev_models['tc'] != tc_model_name):
 
        with st.spinner("Loading models..."):
+           # Clear memory before loading new models
+           gc.collect()
+           if DEVICE == "cpu":
+               torch.set_num_threads(num_threads)
+
            tokenizer_qatc, model_qatc = load_model(qatc_model_name, QATCForQuestionAnswering, device=DEVICE)
            tokenizer_bc, model_bc = load_model(bc_model_name, ClaimModelForClassification, is_bc=True, device=DEVICE)
            tokenizer_tc, model_tc = load_model(tc_model_name, ClaimModelForClassification, device=DEVICE)
@@ -286,6 +339,9 @@ with st.container():
        with st.spinner("Verifying..."):
            start_time = time.time()
 
+           # Clear memory before verification
+           gc.collect()
+
            # Use the optimized verification function
            result = perform_verification(
                preprocessed_claim, preprocessed_context,
@@ -304,7 +360,7 @@ with st.container():
            3-Class Probability: {result['prob3class'].item():.2f}
            3-Class Predicted Label: {['NEI', 'SUPPORTED', 'REFUTED'][result['pred_tc']]}
            2-Class Probability: {result['prob2class'].item():.2f}
-           2-Class Predicted Label: {['SUPPORTED', 'REFUTED'][result['pred_bc']] if result['pred_tc'] != 0 else 'Not used'}
+           2-Class Predicted Label: {['SUPPORTED', 'REFUTED'][result['pred_bc']] if isinstance(result['pred_bc'], int) and result['pred_tc'] != 0 else 'Not used'}
            """
 
            st.session_state.latest_result = {
@@ -323,9 +379,8 @@ with st.container():
            # Add new result to history
            st.session_state.history.append(st.session_state.latest_result)
 
-           # Clear
-
-           torch.cuda.empty_cache()
+           # Clear memory after processing
+           gc.collect()
 
    # Display the result after verification
    res = st.session_state.latest_result
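Note on the quantization step: the dynamic INT8 quantization added to load_model can be tried in isolation. The sketch below is a minimal, self-contained example of the same torch.quantization.quantize_dynamic call, using a hypothetical two-layer torch.nn model instead of the SemViQA checkpoints. Only the nn.Linear weights are converted to INT8, and the technique applies to CPU inference, which is why the diff guards it with `if device == "cpu"`.

import torch
import torch.nn as nn
import torch.quantization

# Hypothetical stand-in for the HuggingFace classifier loaded in app.py
model = nn.Sequential(nn.Linear(768, 256), nn.ReLU(), nn.Linear(256, 3))
model.eval()

# Dynamic quantization: nn.Linear weights are stored as INT8 and
# activations are quantized on the fly during CPU inference
quantized_model = torch.quantization.quantize_dynamic(
    model, {nn.Linear}, dtype=torch.qint8
)

with torch.no_grad():
    logits = quantized_model(torch.randn(1, 768))
print(logits.shape)  # torch.Size([1, 3])

The try/except fallback in the diff keeps the original float32 model whenever quantization is not supported for a given architecture.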
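Note on the threading pattern: the new perform_verification hands the evidence-extraction step to a worker thread and reads the result back through a Queue. Because the thread is joined right after it starts, the work still runs sequentially; the diff's own comment frames the thread as a way to allow garbage collection between stages. Below is a minimal sketch of that Thread-plus-Queue handoff with a hypothetical slow_extract standing in for extract_evidence_tfidf_qatc (all names in the sketch are illustrative, not from the repository). The imported ThreadPoolExecutor would be the natural tool if the stages were ever run concurrently.

import threading
import time
from queue import Queue

# Toy stand-in for extract_evidence_tfidf_qatc (illustrative only)
def slow_extract(claim, context):
    time.sleep(0.1)  # pretend this is QATC inference
    return context[:50]

def extract_threaded(out_queue, claim, context):
    start = time.time()
    evidence = slow_extract(claim, context)
    out_queue.put((evidence, time.time() - start))

q = Queue()
worker = threading.Thread(
    target=extract_threaded,
    args=(q, "claim text", "a long context passage ..."),
)
worker.start()
worker.join()  # joined immediately, so the call is effectively sequential
evidence, elapsed = q.get()
print(evidence, f"({elapsed:.2f}s)")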