Update app.py
app.py CHANGED
@@ -8,14 +8,76 @@ from semviqa.tvc.tvc_eval import classify_claim
 import time
 import pandas as pd
 
-# Load models with caching
+# Load models with caching and optimization
 @st.cache_resource()
-def load_model(model_name, model_class, is_bc=False):
+def load_model(model_name, model_class, is_bc=False, device=None):
+    if device is None:
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+
     tokenizer = AutoTokenizer.from_pretrained(model_name)
     model = model_class.from_pretrained(model_name, num_labels=3 if not is_bc else 2)
     model.eval()
+    model.to(device)
+
+    # Enable CUDA optimizations if available
+    if device == "cuda":
+        if hasattr(model, 'half') and not model_name.startswith("SemViQA/bc-erniem") and not model_name.startswith("SemViQA/tc-erniem"):
+            model = model.half()  # Use FP16 for most models (except ERNIE which might not support it)
+
     return tokenizer, model
 
+# Set device globally
+DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+
+# Pre-process text function to avoid doing it multiple times
+@st.cache_data
+def preprocess_text(text):
+    # Add any text cleaning or normalization here
+    return text.strip()
+
+# Optimized function for evidence extraction and classification
+def perform_verification(claim, context, model_qatc, tokenizer_qatc, model_tc, tokenizer_tc,
+                         model_bc, tokenizer_bc, tfidf_threshold, length_ratio_threshold):
+    with torch.no_grad():
+        # Extract evidence
+        evidence_start_time = time.time()
+        evidence = extract_evidence_tfidf_qatc(
+            claim, context, model_qatc, tokenizer_qatc,
+            DEVICE,
+            confidence_threshold=tfidf_threshold,
+            length_ratio_threshold=length_ratio_threshold
+        )
+        evidence_time = time.time() - evidence_start_time
+
+        # Classify the claim
+        verdict_start_time = time.time()
+        prob3class, pred_tc = classify_claim(
+            claim, evidence, model_tc, tokenizer_tc, DEVICE
+        )
+
+        # Only run binary classifier if needed
+        prob2class, pred_bc = 0, "Not used"
+        if pred_tc != 0:
+            prob2class, pred_bc = classify_claim(
+                claim, evidence, model_bc, tokenizer_bc, DEVICE
+            )
+            verdict = "SUPPORTED" if pred_bc == 0 else "REFUTED" if prob2class > prob3class else ["NEI", "SUPPORTED", "REFUTED"][pred_tc]
+        else:
+            verdict = "NEI"
+
+        verdict_time = time.time() - verdict_start_time
+
+        return {
+            "evidence": evidence,
+            "verdict": verdict,
+            "evidence_time": evidence_time,
+            "verdict_time": verdict_time,
+            "prob3class": prob3class,
+            "pred_tc": pred_tc,
+            "prob2class": prob2class,
+            "pred_bc": pred_bc
+        }
+
 # Set up page configuration
 st.set_page_config(page_title="SemViQA Demo", layout="wide")
 
@@ -166,11 +228,31 @@ with st.container():
         st.session_state.history = []
     if 'latest_result' not in st.session_state:
         st.session_state.latest_result = None
+    if 'models_loaded' not in st.session_state:
+        st.session_state.models_loaded = False
 
-    # Load the selected models
-    tokenizer_qatc, model_qatc = load_model(qatc_model_name, QATCForQuestionAnswering)
-    tokenizer_bc, model_bc = load_model(bc_model_name, ClaimModelForClassification, is_bc=True)
-    tokenizer_tc, model_tc = load_model(tc_model_name, ClaimModelForClassification)
+    # Load the selected models - only reload if model selection changes
+    if not st.session_state.models_loaded or 'prev_models' not in st.session_state or (
+        st.session_state.prev_models['qatc'] != qatc_model_name or
+        st.session_state.prev_models['bc'] != bc_model_name or
+        st.session_state.prev_models['tc'] != tc_model_name):
+
+        with st.spinner("Loading models..."):
+            tokenizer_qatc, model_qatc = load_model(qatc_model_name, QATCForQuestionAnswering, device=DEVICE)
+            tokenizer_bc, model_bc = load_model(bc_model_name, ClaimModelForClassification, is_bc=True, device=DEVICE)
+            tokenizer_tc, model_tc = load_model(tc_model_name, ClaimModelForClassification, device=DEVICE)
+
+            st.session_state.prev_models = {
+                'qatc': qatc_model_name,
+                'bc': bc_model_name,
+                'tc': tc_model_name
+            }
+            st.session_state.models_loaded = True
+    else:
+        # Reuse already loaded models
+        tokenizer_qatc, model_qatc = load_model(qatc_model_name, QATCForQuestionAnswering, device=DEVICE)
+        tokenizer_bc, model_bc = load_model(bc_model_name, ClaimModelForClassification, is_bc=True, device=DEVICE)
+        tokenizer_tc, model_tc = load_model(tc_model_name, ClaimModelForClassification, device=DEVICE)
 
     # Icons for results
     verdict_icons = {
@@ -196,52 +278,41 @@ with st.container():
     with col_result:
         st.markdown("<h3>Verification Result</h3>", unsafe_allow_html=True)
        if verify_button:
+            # Preprocess texts to improve tokenization speed
+            preprocessed_claim = preprocess_text(claim)
+            preprocessed_context = preprocess_text(context)
+
             # Placeholder for displaying result/loading
             with st.spinner("Verifying..."):
-                start_time = time.time()
-
-                # Extract evidence
-                evidence_start_time = time.time()
-                evidence = extract_evidence_tfidf_qatc(
-                    claim, context, model_qatc, tokenizer_qatc,
-                    "cuda" if torch.cuda.is_available() else "cpu",
-                    confidence_threshold=tfidf_threshold,
-                    length_ratio_threshold=length_ratio_threshold
-                )
-                evidence_time = time.time() - evidence_start_time
-                total_time = time.time() - start_time
-
-                # Classify the claim
-                verdict = "NEI"
+                start_time = time.time()
+
+                # Use the optimized verification function
+                result = perform_verification(
+                    preprocessed_claim, preprocessed_context,
+                    model_qatc, tokenizer_qatc,
+                    model_tc, tokenizer_tc,
+                    model_bc, tokenizer_bc,
+                    tfidf_threshold, length_ratio_threshold
+                )
+
+                total_time = time.time() - start_time
+
+                # Format details if needed
                 details = ""
-                verdict_start_time = time.time()
-                with torch.no_grad():
-                    prob2class, pred_bc = 0, "Not used"
-                    prob3class, pred_tc = classify_claim(
-                        claim, evidence, model_tc, tokenizer_tc,
-                        "cuda" if torch.cuda.is_available() else "cpu"
-                    )
-                    if pred_tc != 0:
-                        prob2class, pred_bc = classify_claim(
-                            claim, evidence, model_bc, tokenizer_bc,
-                            "cuda" if torch.cuda.is_available() else "cpu"
-                        )
-                        verdict = "SUPPORTED" if pred_bc == 0 else "REFUTED" if prob2class > prob3class else ["NEI", "SUPPORTED", "REFUTED"][pred_tc]
-                verdict_time = time.time() - verdict_start_time
                 if show_details:
                     details = f"""
-                    3-Class Probability: {prob3class.item():.2f}
-                    3-Class Predicted Label: {['NEI', 'SUPPORTED', 'REFUTED'][pred_tc]}
-                    2-Class Probability: {prob2class.item():.2f}
-                    2-Class Predicted Label: {['SUPPORTED', 'REFUTED'][pred_bc]}
+                    3-Class Probability: {result['prob3class'].item():.2f}
+                    3-Class Predicted Label: {['NEI', 'SUPPORTED', 'REFUTED'][result['pred_tc']]}
+                    2-Class Probability: {result['prob2class'].item():.2f}
+                    2-Class Predicted Label: {['SUPPORTED', 'REFUTED'][result['pred_bc']] if result['pred_tc'] != 0 else 'Not used'}
                     """
-
+
                 st.session_state.latest_result = {
                     "claim": claim,
-                    "evidence": evidence,
-                    "verdict": verdict,
-                    "evidence_time": evidence_time,
-                    "verdict_time": verdict_time,
+                    "evidence": result['evidence'],
+                    "verdict": result['verdict'],
+                    "evidence_time": result['evidence_time'],
+                    "verdict_time": result['verdict_time'],
                     "total_time": total_time,
                     "details": details,
                     "qatc_model": qatc_model_name,
@@ -249,9 +320,10 @@ with st.container():
                    "tc_model": tc_model_name
                }
 
-                #
+                # Add new result to history
                st.session_state.history.append(st.session_state.latest_result)
 
+                # Clear GPU cache to free memory
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
 
@@ -264,8 +336,10 @@ with st.container():
             <p class='verdict'><span class='verdict-icon'>{verdict_icons.get(res['verdict'], '')}</span>{res['verdict']}</p>
             <p><strong>Evidence Inference Time:</strong> {res['evidence_time']:.2f} seconds</p>
             <p><strong>Verdict Inference Time:</strong> {res['verdict_time']:.2f} seconds</p>
+            <p><strong>Total Execution Time:</strong> {res['total_time']:.2f} seconds</p>
             </div>
             """, unsafe_allow_html=True)
+
             # Download Verification Result Feature
             result_text = f"Claim: {res['claim']}\nEvidence: {res['evidence']}\nVerdict: {res['verdict']}\nDetails: {res['details']}"
             st.download_button("Download Result", data=result_text, file_name="verification_result.txt", mime="text/plain")