xuandin committed · verified
Commit 7725101 · 1 parent: 99fbeb9

Update app.py

Files changed (1): app.py +76 -21
app.py CHANGED
@@ -7,8 +7,18 @@ from semviqa.ser.ser_eval import extract_evidence_tfidf_qatc
 from semviqa.tvc.tvc_eval import classify_claim
 import time
 import pandas as pd
+import os
+import psutil
+import gc
+import threading
+from queue import Queue
+from concurrent.futures import ThreadPoolExecutor
 
-# Load models with caching and optimization
+# Set environment variables to optimize CPU performance
+os.environ["OMP_NUM_THREADS"] = str(psutil.cpu_count(logical=False))
+os.environ["MKL_NUM_THREADS"] = str(psutil.cpu_count(logical=False))
+
+# Load models with caching and CPU optimization
 @st.cache_resource()
 def load_model(model_name, model_class, is_bc=False, device=None):
     if device is None:
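The new environment variables pin the OpenMP and MKL thread pools to the number of physical cores, which usually beats the hyper-threaded count for dense inference kernels. One caveat: these variables are read when the native libraries initialize, so they only reliably take effect if set before torch is first imported (the semviqa imports above already pull it in). A minimal standalone sketch of the intended ordering; the print line is illustrative:

import os
import psutil

# Must run before the first `import torch` for OMP/MKL to pick the values up.
physical_cores = psutil.cpu_count(logical=False) or 1
os.environ["OMP_NUM_THREADS"] = str(physical_cores)
os.environ["MKL_NUM_THREADS"] = str(physical_cores)

import torch
print(f"{physical_cores} physical cores, torch intra-op threads: {torch.get_num_threads()}")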
@@ -17,13 +27,20 @@ def load_model(model_name, model_class, is_bc=False, device=None):
     tokenizer = AutoTokenizer.from_pretrained(model_name)
     model = model_class.from_pretrained(model_name, num_labels=3 if not is_bc else 2)
     model.eval()
-    model.to(device)
 
-    # Enable CUDA optimizations if available
-    if device == "cuda":
-        if hasattr(model, 'half') and not model_name.startswith("SemViQA/bc-erniem") and not model_name.startswith("SemViQA/tc-erniem"):
-            model = model.half()  # Use FP16 for most models (except ERNIE which might not support it)
+    # CPU-specific optimizations
+    if device == "cpu":
+        # Use torch's quantization for CPU inference speed boost
+        try:
+            import torch.quantization
+            # Quantize the model to INT8
+            model = torch.quantization.quantize_dynamic(
+                model, {torch.nn.Linear}, dtype=torch.qint8
+            )
+        except Exception as e:
+            st.warning(f"Quantization failed, using default model: {e}")
 
+    model.to(device)
     return tokenizer, model
 
 # Set device globally
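The CUDA/FP16 branch gives way to dynamic INT8 quantization, which replaces every nn.Linear so that weights are stored as INT8 and activations are quantized on the fly; on transformer encoders this typically cuts CPU latency and memory at a small accuracy cost. A self-contained sketch on a toy model (not the SemViQA checkpoints):

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(256, 128), nn.ReLU(), nn.Linear(128, 3)).eval()

# Dynamic quantization: INT8 weights, activations quantized per batch at runtime.
qmodel = torch.quantization.quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)

with torch.no_grad():
    out = qmodel(torch.randn(1, 256))
print(out.shape)  # torch.Size([1, 3])

Moving model.to(device) after the try/except is also consistent with this: quantize_dynamic returns a CPU model, and a .to("cpu") on it is a no-op.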
@@ -35,28 +52,52 @@ def preprocess_text(text):
     # Add any text cleaning or normalization here
     return text.strip()
 
-# Optimized function for evidence extraction and classification
-def perform_verification(claim, context, model_qatc, tokenizer_qatc, model_tc, tokenizer_tc,
-                         model_bc, tokenizer_bc, tfidf_threshold, length_ratio_threshold):
+# Function to extract evidence in a separate thread for better CPU utilization
+def extract_evidence_threaded(queue, claim, context, model_qatc, tokenizer_qatc, device,
+                              tfidf_threshold, length_ratio_threshold):
+    start_time = time.time()
     with torch.no_grad():
-        # Extract evidence
-        evidence_start_time = time.time()
         evidence = extract_evidence_tfidf_qatc(
             claim, context, model_qatc, tokenizer_qatc,
-            DEVICE,
+            device,
             confidence_threshold=tfidf_threshold,
             length_ratio_threshold=length_ratio_threshold
         )
-        evidence_time = time.time() - evidence_start_time
+    evidence_time = time.time() - start_time
+    queue.put((evidence, evidence_time))
+
+# Function to classify in a separate thread
+def classify_claim_threaded(queue, claim, evidence, model, tokenizer, device):
+    with torch.no_grad():
+        result = classify_claim(claim, evidence, model, tokenizer, device)
+    queue.put(result)
 
-        # Classify the claim
-        verdict_start_time = time.time()
+# Optimized function for evidence extraction and classification with better CPU performance
+def perform_verification(claim, context, model_qatc, tokenizer_qatc, model_tc, tokenizer_tc,
+                         model_bc, tokenizer_bc, tfidf_threshold, length_ratio_threshold):
+    # Use thread for evidence extraction to allow garbage collection in between
+    evidence_queue = Queue()
+    evidence_thread = threading.Thread(
+        target=extract_evidence_threaded,
+        args=(evidence_queue, claim, context, model_qatc, tokenizer_qatc, DEVICE,
+              tfidf_threshold, length_ratio_threshold)
+    )
+    evidence_thread.start()
+    evidence_thread.join()
+    evidence, evidence_time = evidence_queue.get()
+
+    # Explicit garbage collection after evidence extraction
+    gc.collect()
+
+    # Classify the claim
+    verdict_start_time = time.time()
+    with torch.no_grad():
         prob3class, pred_tc = classify_claim(
             claim, evidence, model_tc, tokenizer_tc, DEVICE
         )
 
         # Only run binary classifier if needed
-        prob2class, pred_bc = 0, "Not used"
+        prob2class, pred_bc = 0, 0
         if pred_tc != 0:
             prob2class, pred_bc = classify_claim(
                 claim, evidence, model_bc, tokenizer_bc, DEVICE
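perform_verification now runs evidence extraction in a worker thread that reports back through a Queue. Since the main thread calls join() immediately, execution is still sequential; the thread mainly fences off the allocation-heavy extraction step so the gc.collect() that follows can reclaim its temporaries. The Thread-plus-Queue pattern in isolation, with a stand-in work function rather than the real extractor:

import threading
import time
from queue import Queue

def worker(out_queue, n):
    # Stand-in for extract_evidence_tfidf_qatc: do some work, time it,
    # and hand (result, elapsed) back through the queue.
    start = time.time()
    result = sum(i * i for i in range(n))
    out_queue.put((result, time.time() - start))

q = Queue()
t = threading.Thread(target=worker, args=(q, 1_000_000))
t.start()
t.join()                      # wait for the worker: still sequential overall
result, elapsed = q.get()
print(result, f"{elapsed:.3f}s")

classify_claim_threaded and the ThreadPoolExecutor import are never called in this commit; they read as groundwork for running the two classifier heads concurrently later.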
@@ -65,7 +106,7 @@ def perform_verification(claim, context, model_qatc, tokenizer_qatc, model_tc, tokenizer_tc,
         else:
             verdict = "NEI"
 
-        verdict_time = time.time() - verdict_start_time
+    verdict_time = time.time() - verdict_start_time
 
     return {
         "evidence": evidence,
@@ -222,6 +263,13 @@ with st.container():
         "SemViQA/tc-erniem-isedsc01"
     ])
     show_details = st.checkbox("Show Probability Details", value=False)
+
+    # Add CPU optimization settings
+    st.subheader("CPU Performance Settings")
+    num_threads = st.slider("Number of CPU Threads", 1, psutil.cpu_count(),
+                            psutil.cpu_count(logical=False))
+    os.environ["OMP_NUM_THREADS"] = str(num_threads)
+    os.environ["MKL_NUM_THREADS"] = str(num_threads)
 
 # Store verification history
 if 'history' not in st.session_state:
@@ -238,6 +286,11 @@ with st.container():
         st.session_state.prev_models['tc'] != tc_model_name):
 
         with st.spinner("Loading models..."):
+            # Clear memory before loading new models
+            gc.collect()
+            if DEVICE == "cpu":
+                torch.set_num_threads(num_threads)
+
             tokenizer_qatc, model_qatc = load_model(qatc_model_name, QATCForQuestionAnswering, device=DEVICE)
             tokenizer_bc, model_bc = load_model(bc_model_name, ClaimModelForClassification, is_bc=True, device=DEVICE)
             tokenizer_tc, model_tc = load_model(tc_model_name, ClaimModelForClassification, device=DEVICE)
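Unlike the OMP/MKL environment variables, torch.set_num_threads applies at call time, so wiring it to the sidebar slider re-tunes intra-op parallelism without restarting the app. A minimal check:

import torch

torch.set_num_threads(4)        # intra-op threads used by CPU kernels (e.g. matmul)
print(torch.get_num_threads())  # 4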
@@ -286,6 +339,9 @@ with st.container():
     with st.spinner("Verifying..."):
         start_time = time.time()
 
+        # Clear memory before verification
+        gc.collect()
+
         # Use the optimized verification function
         result = perform_verification(
             preprocessed_claim, preprocessed_context,
@@ -304,7 +360,7 @@ with st.container():
         3-Class Probability: {result['prob3class'].item():.2f}
         3-Class Predicted Label: {['NEI', 'SUPPORTED', 'REFUTED'][result['pred_tc']]}
         2-Class Probability: {result['prob2class'].item():.2f}
-        2-Class Predicted Label: {['SUPPORTED', 'REFUTED'][result['pred_bc']] if result['pred_tc'] != 0 else 'Not used'}
+        2-Class Predicted Label: {['SUPPORTED', 'REFUTED'][result['pred_bc']] if isinstance(result['pred_bc'], int) and result['pred_tc'] != 0 else 'Not used'}
         """
 
         st.session_state.latest_result = {
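With the binary head's sentinel changed from the string 'Not used' to 0, indexing result['pred_bc'] is always type-safe, and the isinstance guard keeps the label at 'Not used' whenever the three-class model predicted NEI. The guard's behavior on illustrative values:

def bc_label(pred_tc, pred_bc):
    # Mirrors the f-string guard: only trust pred_bc when the three-class
    # head did not predict NEI (index 0) and pred_bc is a plain int.
    return ['SUPPORTED', 'REFUTED'][pred_bc] if isinstance(pred_bc, int) and pred_tc != 0 else 'Not used'

print(bc_label(1, 0))  # SUPPORTED
print(bc_label(0, 0))  # Not used (NEI case)

One caveat worth verifying: if classify_claim returns pred_bc as a tensor rather than a Python int, the isinstance check would route real predictions to 'Not used'.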
@@ -323,9 +379,8 @@ with st.container():
         # Add new result to history
         st.session_state.history.append(st.session_state.latest_result)
 
-        # Clear GPU cache to free memory
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
+        # Clear memory after processing
+        gc.collect()
 
     # Display the result after verification
     res = st.session_state.latest_result
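Swapping torch.cuda.empty_cache() for gc.collect() matches the CPU target: there is no CUDA caching allocator to flush, and what lingers on CPU are Python references to intermediate tensors. A device-agnostic cleanup helper would cover both cases; a minimal sketch (the helper name is mine, not from the commit):

import gc
import torch

def free_memory():
    # Release unreachable Python objects (covers CPU tensors), then
    # return cached CUDA blocks to the driver if a GPU is present.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()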
 