xuandin committed (verified)
Commit: 99fbeb9 · Parent(s): 4af1026

Update app.py

Files changed (1):
  app.py  +119 -45
app.py CHANGED
@@ -8,14 +8,76 @@ from semviqa.tvc.tvc_eval import classify_claim
 import time
 import pandas as pd
 
-# Load models with caching
+# Load models with caching and optimization
 @st.cache_resource()
-def load_model(model_name, model_class, is_bc=False):
+def load_model(model_name, model_class, is_bc=False, device=None):
+    if device is None:
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+
     tokenizer = AutoTokenizer.from_pretrained(model_name)
     model = model_class.from_pretrained(model_name, num_labels=3 if not is_bc else 2)
     model.eval()
+    model.to(device)
+
+    # Enable CUDA optimizations if available
+    if device == "cuda":
+        if hasattr(model, 'half') and not model_name.startswith("SemViQA/bc-erniem") and not model_name.startswith("SemViQA/tc-erniem"):
+            model = model.half()  # Use FP16 for most models (except ERNIE which might not support it)
+
     return tokenizer, model
 
+# Set device globally
+DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+
+# Pre-process text function to avoid doing it multiple times
+@st.cache_data
+def preprocess_text(text):
+    # Add any text cleaning or normalization here
+    return text.strip()
+
+# Optimized function for evidence extraction and classification
+def perform_verification(claim, context, model_qatc, tokenizer_qatc, model_tc, tokenizer_tc,
+                         model_bc, tokenizer_bc, tfidf_threshold, length_ratio_threshold):
+    with torch.no_grad():
+        # Extract evidence
+        evidence_start_time = time.time()
+        evidence = extract_evidence_tfidf_qatc(
+            claim, context, model_qatc, tokenizer_qatc,
+            DEVICE,
+            confidence_threshold=tfidf_threshold,
+            length_ratio_threshold=length_ratio_threshold
+        )
+        evidence_time = time.time() - evidence_start_time
+
+        # Classify the claim
+        verdict_start_time = time.time()
+        prob3class, pred_tc = classify_claim(
+            claim, evidence, model_tc, tokenizer_tc, DEVICE
+        )
+
+        # Only run binary classifier if needed
+        prob2class, pred_bc = 0, "Not used"
+        if pred_tc != 0:
+            prob2class, pred_bc = classify_claim(
+                claim, evidence, model_bc, tokenizer_bc, DEVICE
+            )
+            verdict = "SUPPORTED" if pred_bc == 0 else "REFUTED" if prob2class > prob3class else ["NEI", "SUPPORTED", "REFUTED"][pred_tc]
+        else:
+            verdict = "NEI"
+
+        verdict_time = time.time() - verdict_start_time
+
+    return {
+        "evidence": evidence,
+        "verdict": verdict,
+        "evidence_time": evidence_time,
+        "verdict_time": verdict_time,
+        "prob3class": prob3class,
+        "pred_tc": pred_tc,
+        "prob2class": prob2class,
+        "pred_bc": pred_bc
+    }
+
 # Set up page configuration
 st.set_page_config(page_title="SemViQA Demo", layout="wide")
 
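Note on the FP16 branch in load_model above: .half() converts the weights to float16, roughly halving GPU memory and speeding up CUDA inference, but floating-point inputs then have to be cast to the same dtype before the forward pass. The sketch below illustrates that caveat in isolation; TinyClassifier, maybe_half, and the model name are hypothetical, not part of app.py.

    import torch

    class TinyClassifier(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = torch.nn.Linear(8, 2)

        def forward(self, x):
            return self.fc(x)

    # Mirror the commit's policy: convert to FP16 only on CUDA, skipping
    # checkpoints excluded by name (the ERNIE-M models in the diff above).
    def maybe_half(model, model_name, device):
        fp16_blocklist = ("SemViQA/bc-erniem", "SemViQA/tc-erniem")
        if device == "cuda" and not model_name.startswith(fp16_blocklist):
            model = model.half()
        return model

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = maybe_half(TinyClassifier().eval().to(device), "example/model", device)

    x = torch.randn(1, 8, device=device)
    if next(model.parameters()).dtype == torch.float16:
        x = x.half()  # inputs must match the FP16 weights or forward() raises a dtype error
    with torch.no_grad():
        print(model(x))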
@@ -166,11 +228,31 @@ with st.container():
         st.session_state.history = []
     if 'latest_result' not in st.session_state:
         st.session_state.latest_result = None
+    if 'models_loaded' not in st.session_state:
+        st.session_state.models_loaded = False
 
-    # Load the selected models
-    tokenizer_qatc, model_qatc = load_model(qatc_model_name, QATCForQuestionAnswering)
-    tokenizer_bc, model_bc = load_model(bc_model_name, ClaimModelForClassification, is_bc=True)
-    tokenizer_tc, model_tc = load_model(tc_model_name, ClaimModelForClassification)
+    # Load the selected models - only reload if model selection changes
+    if not st.session_state.models_loaded or 'prev_models' not in st.session_state or (
+        st.session_state.prev_models['qatc'] != qatc_model_name or
+        st.session_state.prev_models['bc'] != bc_model_name or
+        st.session_state.prev_models['tc'] != tc_model_name):
+
+        with st.spinner("Loading models..."):
+            tokenizer_qatc, model_qatc = load_model(qatc_model_name, QATCForQuestionAnswering, device=DEVICE)
+            tokenizer_bc, model_bc = load_model(bc_model_name, ClaimModelForClassification, is_bc=True, device=DEVICE)
+            tokenizer_tc, model_tc = load_model(tc_model_name, ClaimModelForClassification, device=DEVICE)
+
+        st.session_state.prev_models = {
+            'qatc': qatc_model_name,
+            'bc': bc_model_name,
+            'tc': tc_model_name
+        }
+        st.session_state.models_loaded = True
+    else:
+        # Reuse already loaded models
+        tokenizer_qatc, model_qatc = load_model(qatc_model_name, QATCForQuestionAnswering, device=DEVICE)
+        tokenizer_bc, model_bc = load_model(bc_model_name, ClaimModelForClassification, is_bc=True, device=DEVICE)
+        tokenizer_tc, model_tc = load_model(tc_model_name, ClaimModelForClassification, device=DEVICE)
 
     # Icons for results
     verdict_icons = {
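Why the else branch above can simply call load_model again instead of stashing models in st.session_state: @st.cache_resource memoizes by argument values, so a repeated call with the same model name is a cache lookup, not a reload. A minimal sketch of that behavior, assuming a script run via streamlit run (expensive_load is an illustrative stand-in, not a function in app.py):

    import streamlit as st

    @st.cache_resource()
    def expensive_load(name: str) -> dict:
        st.write(f"cache miss: loading {name}")  # shown only on the first call per name
        return {"name": name}

    a = expensive_load("model-a")  # miss: runs the body and caches the object
    b = expensive_load("model-a")  # hit: returns the identical cached object
    assert a is b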
@@ -196,52 +278,41 @@ with st.container():
     with col_result:
         st.markdown("<h3>Verification Result</h3>", unsafe_allow_html=True)
         if verify_button:
+            # Preprocess texts to improve tokenization speed
+            preprocessed_claim = preprocess_text(claim)
+            preprocessed_context = preprocess_text(context)
+
             # Placeholder for displaying result/loading
             with st.spinner("Verifying..."):
-                start_time = time.time()
-
-                # Extract evidence
-                evidence_start_time = time.time()
-                with torch.no_grad():
-                    evidence = extract_evidence_tfidf_qatc(
-                        claim, context, model_qatc, tokenizer_qatc,
-                        "cuda" if torch.cuda.is_available() else "cpu",
-                        confidence_threshold=tfidf_threshold,
-                        length_ratio_threshold=length_ratio_threshold
-                    )
-                evidence_time = time.time() - evidence_start_time
-
-                # Classify the claim
-                verdict = "NEI"
+                start_time = time.time()
+
+                # Use the optimized verification function
+                result = perform_verification(
+                    preprocessed_claim, preprocessed_context,
+                    model_qatc, tokenizer_qatc,
+                    model_tc, tokenizer_tc,
+                    model_bc, tokenizer_bc,
+                    tfidf_threshold, length_ratio_threshold
+                )
+
+                total_time = time.time() - start_time
+
+                # Format details if needed
                 details = ""
-                verdict_start_time = time.time()
-                with torch.no_grad():
-                    prob2class, pred_bc = 0, "Not used"
-                    prob3class, pred_tc = classify_claim(
-                        claim, evidence, model_tc, tokenizer_tc,
-                        "cuda" if torch.cuda.is_available() else "cpu"
-                    )
-                    if pred_tc != 0:
-                        prob2class, pred_bc = classify_claim(
-                            claim, evidence, model_bc, tokenizer_bc,
-                            "cuda" if torch.cuda.is_available() else "cpu"
-                        )
-                        verdict = "SUPPORTED" if pred_bc == 0 else "REFUTED" if prob2class > prob3class else ["NEI", "SUPPORTED", "REFUTED"][pred_tc]
-                verdict_time = time.time() - verdict_start_time
                 if show_details:
                     details = f"""
-                    3-Class Probability: {prob3class.item():.2f}
-                    3-Class Predicted Label: {['NEI', 'SUPPORTED', 'REFUTED'][pred_tc]}
-                    2-Class Probability: {prob2class.item():.2f}
-                    2-Class Predicted Label: {['SUPPORTED', 'REFUTED'][pred_bc]}
+                    3-Class Probability: {result['prob3class'].item():.2f}
+                    3-Class Predicted Label: {['NEI', 'SUPPORTED', 'REFUTED'][result['pred_tc']]}
+                    2-Class Probability: {(result['prob2class'].item() if result['pred_tc'] != 0 else 0.0):.2f}
+                    2-Class Predicted Label: {['SUPPORTED', 'REFUTED'][result['pred_bc']] if result['pred_tc'] != 0 else 'Not used'}
                     """
-                total_time = time.time() - start_time
+
                 st.session_state.latest_result = {
                     "claim": claim,
-                    "evidence": evidence,
-                    "verdict": verdict,
-                    "evidence_time": evidence_time,
-                    "verdict_time": verdict_time,
+                    "evidence": result['evidence'],
+                    "verdict": result['verdict'],
+                    "evidence_time": result['evidence_time'],
+                    "verdict_time": result['verdict_time'],
                     "total_time": total_time,
                     "details": details,
                     "qatc_model": qatc_model_name,
@@ -249,9 +320,10 @@ with st.container():
                     "tc_model": tc_model_name
                 }
 
-                # Thêm kết quả mới vào lịch sử
+                # Add new result to history
                 st.session_state.history.append(st.session_state.latest_result)
 
+                # Clear GPU cache to free memory
                 if torch.cuda.is_available():
                     torch.cuda.empty_cache()
 
@@ -264,8 +336,10 @@ with st.container():
                 <p class='verdict'><span class='verdict-icon'>{verdict_icons.get(res['verdict'], '')}</span>{res['verdict']}</p>
                 <p><strong>Evidence Inference Time:</strong> {res['evidence_time']:.2f} seconds</p>
                 <p><strong>Verdict Inference Time:</strong> {res['verdict_time']:.2f} seconds</p>
+                <p><strong>Total Execution Time:</strong> {res['total_time']:.2f} seconds</p>
             </div>
             """, unsafe_allow_html=True)
+
             # Download Verification Result Feature
             result_text = f"Claim: {res['claim']}\nEvidence: {res['evidence']}\nVerdict: {res['verdict']}\nDetails: {res['details']}"
             st.download_button("Download Result", data=result_text, file_name="verification_result.txt", mime="text/plain")
 
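A closing note on the two caching decorators this commit relies on: st.cache_data hashes inputs and serializes return values, handing each caller a fresh copy, which fits pure transforms like preprocess_text; st.cache_resource stores one shared, uncopied object, which fits models and tokenizers. A contrastive sketch (clean and load_tokenizer are illustrative names, not functions in app.py):

    import streamlit as st

    @st.cache_data
    def clean(text: str) -> str:
        # Data cache: the return value is serialized and a copy is returned,
        # so callers can mutate their result without corrupting the cache.
        return text.strip()

    @st.cache_resource()
    def load_tokenizer(name: str):
        # Resource cache: every caller shares the one cached instance,
        # which is what you want for large models and tokenizers.
        from transformers import AutoTokenizer
        return AutoTokenizer.from_pretrained(name)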