import streamlit as st
import torch
from transformers import AutoTokenizer
from semviqa.ser.qatc_model import QATCForQuestionAnswering
from semviqa.tvc.model import ClaimModelForClassification
from semviqa.ser.ser_eval import extract_evidence_tfidf_qatc
from semviqa.tvc.tvc_eval import classify_claim
import time
import pandas as pd
import os
import psutil
import gc
from functools import lru_cache
from concurrent.futures import ThreadPoolExecutor
# Set environment variables to optimize CPU performance
physical_cores = psutil.cpu_count(logical=False) or 1  # cpu_count can return None
os.environ["OMP_NUM_THREADS"] = str(physical_cores)
os.environ["MKL_NUM_THREADS"] = str(physical_cores)
torch.set_num_threads(physical_cores)

# Set device globally
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Cache model outputs for repeated (claim, evidence, model) combinations.
# The lru_cache decorator is assumed here from the original "Cache for model
# outputs" note and the otherwise-unused functools import; all arguments are
# hashable strings/bools, which lru_cache requires.
@lru_cache(maxsize=128)
def cached_classify_claim(claim, evidence, model_name, is_bc=False):
    tokenizer, model = load_model(model_name, ClaimModelForClassification, is_bc=is_bc, device=DEVICE)
    with torch.no_grad():
        prob, pred = classify_claim(claim, evidence, model, tokenizer, DEVICE)
    return prob, pred
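
# If memory grows during a long session, the cache can be emptied explicitly:
#   cached_classify_claim.cache_clear()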
# Optimized model loading with caching (cached for 1 hour).
# st.cache_resource(ttl=3600) is assumed from the original "Cache for 1 hour"
# note; the leading underscore on _model_class excludes the class object from
# Streamlit's cache key (model_name already identifies the checkpoint).
@st.cache_resource(ttl=3600)
def load_model(model_name, _model_class, is_bc=False, device=None):
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = _model_class.from_pretrained(model_name, num_labels=2 if is_bc else 3)
    model.eval()
    # Optimize model for inference
    if device == "cuda":
        model = model.half()  # FP16 for faster inference and lower memory
        torch.cuda.empty_cache()
    model.to(device)
    return tokenizer, model
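
# Usage sketch (mirrors the calls made further down in this file):
#   tokenizer, model = load_model("SemViQA/qatc-infoxlm-viwikifc",
#                                 QATCForQuestionAnswering, device=DEVICE)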
# Text preprocessing (currently a minimal pass-through)
def preprocess_text(text):
    # Trim surrounding whitespace; add further cleaning or normalization here
    return text.strip()
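
# A fuller normalizer (hypothetical extension, not part of the original app)
# might NFC-normalize Unicode -- relevant for Vietnamese diacritics -- and
# collapse runs of whitespace:
#   import re, unicodedata
#   text = unicodedata.normalize("NFC", text)
#   text = re.sub(r"\s+", " ", text).strip()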
# Chunked evidence extraction (pairs are still processed one at a time within
# each chunk; chunking mainly bounds memory use between gc/cache clears)
def batch_extract_evidence(claims, contexts, model_qatc, tokenizer_qatc, batch_size=4):
    results = []
    for i in range(0, len(claims), batch_size):
        batch_claims = claims[i:i + batch_size]
        batch_contexts = contexts[i:i + batch_size]
        with torch.no_grad():
            batch_results = [
                extract_evidence_tfidf_qatc(
                    claim, context, model_qatc, tokenizer_qatc,
                    DEVICE,
                    confidence_threshold=0.5,
                    length_ratio_threshold=0.5
                )
                for claim, context in zip(batch_claims, batch_contexts)
            ]
        results.extend(batch_results)
    return results
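
# Hypothetical usage sketch (models/tokenizers come from load_model above):
#   claims = ["claim 1", "claim 2"]
#   contexts = ["context 1", "context 2"]
#   evidences = batch_extract_evidence(claims, contexts, model_qatc, tokenizer_qatc)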
# Verification pipeline: evidence extraction (SER) followed by two-step
# verdict classification (TVC), with the two classifiers run in parallel.
# Note: tc_model_name and bc_model_name are globals set in the sidebar below.
def perform_verification(claim, context, model_qatc, tokenizer_qatc, model_tc, tokenizer_tc,
                         model_bc, tokenizer_bc, tfidf_threshold, length_ratio_threshold):
    # Extract evidence with the user-selected thresholds
    evidence_start_time = time.time()
    evidence = extract_evidence_tfidf_qatc(
        claim, context, model_qatc, tokenizer_qatc,
        DEVICE,
        confidence_threshold=tfidf_threshold,
        length_ratio_threshold=length_ratio_threshold
    )
    evidence_time = time.time() - evidence_start_time

    # Clear memory after evidence extraction
    if DEVICE == "cuda":
        torch.cuda.empty_cache()
    gc.collect()

    verdict_start_time = time.time()

    # Run the three-class and binary classifiers in parallel
    with ThreadPoolExecutor(max_workers=2) as executor:
        future_tc = executor.submit(cached_classify_claim, claim, evidence, tc_model_name, False)
        future_bc = executor.submit(cached_classify_claim, claim, evidence, bc_model_name, True)
        prob3class, pred_tc = future_tc.result()
        prob2class, pred_bc = future_bc.result()

    # Two-step verdict: the three-class model screens for NEI; when it predicts
    # SUPPORTED/REFUTED, the binary model refines the call, and the more
    # confident of the two breaks any disagreement.
    verdict = "NEI"
    if pred_tc != 0:
        if pred_bc == 0:
            verdict = "SUPPORTED"
        elif prob2class > prob3class:
            verdict = "REFUTED"
        else:
            verdict = ["NEI", "SUPPORTED", "REFUTED"][pred_tc]
    verdict_time = time.time() - verdict_start_time

    return {
        "evidence": evidence,
        "verdict": verdict,
        "evidence_time": evidence_time,
        "verdict_time": verdict_time,
        "prob3class": prob3class.item() if isinstance(prob3class, torch.Tensor) else prob3class,
        "pred_tc": pred_tc,
        "prob2class": prob2class.item() if isinstance(prob2class, torch.Tensor) else prob2class,
        "pred_bc": pred_bc
    }
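
# Hypothetical usage sketch, assuming the models/tokenizers loaded below:
#   result = perform_verification(
#       "some claim", "some context",
#       model_qatc, tokenizer_qatc, model_tc, tokenizer_tc,
#       model_bc, tokenizer_bc,
#       tfidf_threshold=0.5, length_ratio_threshold=0.5,
#   )
#   print(result["verdict"], result["evidence"])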
# Performance monitoring
def monitor_performance():
    if DEVICE == "cuda":
        return {
            "gpu_memory_used": torch.cuda.memory_allocated() / 1024**2,
            "gpu_memory_cached": torch.cuda.memory_reserved() / 1024**2,
            "cpu_percent": psutil.cpu_percent(),
            "memory_percent": psutil.virtual_memory().percent
        }
    return {
        "cpu_percent": psutil.cpu_percent(),
        "memory_percent": psutil.virtual_memory().percent
    }
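
# Example shape of the returned metrics on a CUDA machine (values illustrative):
#   {"gpu_memory_used": 1234.56, "gpu_memory_cached": 2048.00,
#    "cpu_percent": 37.5, "memory_percent": 62.1}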
# Set page configuration
st.set_page_config(
    page_title="SemViQA - A Semantic Question Answering System for Vietnamese Information Fact-Checking",
    page_icon="🔍",  # the original icon character was mis-encoded; 🔍 assumed
    layout="wide",
    initial_sidebar_state="expanded"
)
# Custom CSS
st.markdown("""
    <style>
    /* Main theme colors */
    :root {
        --primary-color: #1f77b4;
        --secondary-color: #2c3e50;
        --accent-color: #e74c3c;
        --background-color: #f8f9fa;
        --text-color: #2c3e50;
        --border-color: #e0e0e0;
    }

    /* General styling */
    .stApp {
        background-color: var(--background-color);
        color: var(--text-color);
    }

    /* Header styling */
    .main-header {
        background: linear-gradient(135deg, var(--primary-color), var(--secondary-color));
        color: white;
        padding: 2rem;
        border-radius: 10px;
        margin-bottom: 2rem;
        box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
    }

    .main-title {
        font-size: 2.5rem;
        font-weight: bold;
        text-align: center;
        margin-bottom: 1rem;
    }

    .sub-title {
        font-size: 1.2rem;
        text-align: center;
        opacity: 0.9;
    }

    /* Input styling */
    .stTextArea textarea {
        border: 2px solid var(--border-color);
        border-radius: 8px;
        padding: 1rem;
        font-size: 1rem;
        min-height: 150px;
        background-color: white;
    }

    /* Button styling */
    .stButton>button {
        background: linear-gradient(135deg, var(--primary-color), var(--secondary-color));
        color: white;
        border: none;
        border-radius: 8px;
        padding: 0.8rem 2rem;
        font-size: 1.1rem;
        font-weight: bold;
        transition: all 0.3s ease;
    }

    .stButton>button:hover {
        transform: translateY(-2px);
        box-shadow: 0 4px 8px rgba(0, 0, 0, 0.2);
    }

    /* Result box styling */
    .result-box {
        background-color: white;
        border-radius: 12px;
        padding: 2rem;
        margin: 1rem 0;
        box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
    }

    /* Info section styling */
    .info-section {
        background-color: white;
        border-radius: 12px;
        padding: 2rem;
        margin: 1rem 0;
        box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
    }

    .info-section h3 {
        color: var(--primary-color);
        font-size: 1.8rem;
        margin-bottom: 1.5rem;
        border-bottom: 2px solid var(--border-color);
        padding-bottom: 0.5rem;
    }

    .info-section h4 {
        color: var(--secondary-color);
        font-size: 1.4rem;
        margin: 1.5rem 0 1rem 0;
    }

    .info-section p {
        font-size: 1.1rem;
        line-height: 1.6;
        color: var(--text-color);
        margin-bottom: 1.5rem;
    }

    .info-section ol, .info-section ul {
        margin-left: 1.5rem;
        margin-bottom: 1.5rem;
    }

    .info-section li {
        font-size: 1.1rem;
        line-height: 1.6;
        margin-bottom: 0.5rem;
    }

    .info-section strong {
        color: var(--primary-color);
    }

    .verdict {
        font-size: 1.8rem;
        font-weight: bold;
        padding: 1rem;
        border-radius: 8px;
        margin: 1rem 0;
        text-align: center;
    }

    .verdict-supported {
        background-color: #d4edda;
        color: #155724;
    }

    .verdict-refuted {
        background-color: #f8d7da;
        color: #721c24;
    }

    .verdict-nei {
        background-color: #fff3cd;
        color: #856404;
    }
    /* Sidebar styling (targets an auto-generated Streamlit class name, which
       may change between Streamlit versions) */
    .css-1d391kg {
        background-color: white;
        padding: 2rem;
        border-radius: 12px;
        box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
    }
    /* Stats box styling */
    .stats-box {
        background-color: white;
        border-radius: 8px;
        padding: 1rem;
        margin: 0.5rem 0;
        box-shadow: 0 2px 4px rgba(0, 0, 0, 0.05);
    }

    /* Code block styling */
    .code-block {
        background-color: #f8f9fa;
        border: 1px solid var(--border-color);
        border-radius: 8px;
        padding: 1rem;
        font-family: monospace;
        margin: 1rem 0;
    }

    /* Tab styling */
    .stTabs [data-baseweb="tab-list"] {
        gap: 2rem;
    }

    .stTabs [data-baseweb="tab"] {
        background-color: white;
        border-radius: 8px;
        padding: 0.5rem 1rem;
        margin: 0 0.5rem;
    }

    .stTabs [aria-selected="true"] {
        background-color: var(--primary-color);
        color: white;
    }
    </style>
""", unsafe_allow_html=True)
# Main header
st.markdown("""
    <div class="main-header">
        <div class="main-title">SemViQA</div>
        <div class="sub-title">A Semantic Question Answering System for Vietnamese Information Fact-Checking</div>
    </div>
""", unsafe_allow_html=True)
# Sidebar
with st.sidebar:
    st.markdown("### ⚙️ System Settings")

    # Model selection
    st.markdown("#### 🔧 Model Selection")
    qatc_model_name = st.selectbox(
        "QATC Model",
        [
            "SemViQA/qatc-infoxlm-viwikifc",
            "SemViQA/qatc-infoxlm-isedsc01",
            "SemViQA/qatc-vimrc-viwikifc",
            "SemViQA/qatc-vimrc-isedsc01"
        ]
    )
    bc_model_name = st.selectbox(
        "Binary Classification Model",
        [
            "SemViQA/bc-xlmr-viwikifc",
            "SemViQA/bc-xlmr-isedsc01",
            "SemViQA/bc-infoxlm-viwikifc",
            "SemViQA/bc-infoxlm-isedsc01",
            "SemViQA/bc-erniem-viwikifc",
            "SemViQA/bc-erniem-isedsc01"
        ]
    )
    tc_model_name = st.selectbox(
        "Three-Class Classification Model",
        [
            "SemViQA/tc-xlmr-viwikifc",
            "SemViQA/tc-xlmr-isedsc01",
            "SemViQA/tc-infoxlm-viwikifc",
            "SemViQA/tc-infoxlm-isedsc01",
            "SemViQA/tc-erniem-viwikifc",
            "SemViQA/tc-erniem-isedsc01"
        ]
    )

    # Threshold settings
    st.markdown("#### ⚖️ Analysis Thresholds")
    tfidf_threshold = st.slider(
        "Confidence Threshold",
        0.0, 1.0, 0.5,
        help="Adjust the sensitivity of the evidence search"
    )
    length_ratio_threshold = st.slider(
        "Length Ratio Threshold",
        0.1, 1.0, 0.5,
        help="Adjust the maximum evidence length relative to the claim"
    )

    # Display settings
    st.markdown("#### 👁️ Display")
    show_details = st.checkbox(
        "Show Probability Details",
        value=False,
        help="Display detailed probability information"
    )

    # Performance settings
    st.markdown("#### ⚡ Performance")
    num_threads = st.slider(
        "CPU Threads",
        1, psutil.cpu_count() or 1,
        physical_cores,
        help="Adjust processing performance"
    )
    # OMP/MKL read these variables when the libraries initialize, so changing
    # them mid-session has limited effect; torch.set_num_threads (added here)
    # applies immediately.
    os.environ["OMP_NUM_THREADS"] = str(num_threads)
    os.environ["MKL_NUM_THREADS"] = str(num_threads)
    torch.set_num_threads(num_threads)
# Main content
tabs = st.tabs(["🔍 Verify", "📜 History", "ℹ️ Info"])

tokenizer_qatc, model_qatc = load_model(qatc_model_name, QATCForQuestionAnswering, device=DEVICE)
tokenizer_bc, model_bc = load_model(bc_model_name, ClaimModelForClassification, is_bc=True, device=DEVICE)
tokenizer_tc, model_tc = load_model(tc_model_name, ClaimModelForClassification, device=DEVICE)

verdict_icons = {
    "SUPPORTED": "✅",
    "REFUTED": "❌",
    "NEI": "⚠️"
}
# --- Tab Verify ---
with tabs[0]:
    col1, col2 = st.columns([2, 1])

    with col1:
        st.markdown("### 📝 Input Information")
        # Default claim (Vietnamese): "The war with Cambodia had ended before
        # Vietnam was reunified."
        claim = st.text_area(
            "Claim to Verify",
            "Chiến tranh với Campuchia đã kết thúc trước khi Việt Nam thống nhất.",
            help="Enter the claim to be verified"
        )
        # Default context: a Vietnamese passage on post-reunification Vietnam
        # (economic hardship, the 1986 Doi Moi reforms, growth, and human-rights criticism).
        context = st.text_area(
            "Context",
            "Sau khi thống nhất, Việt Nam tiếp tục gặp khó khăn do sự sụp đổ và tan rã của đồng minh Liên Xô cùng Khối phía Đông, các lệnh cấm vận của Hoa Kỳ, chiến tranh với Campuchia, biên giới giáp Trung Quốc và hậu quả của chính sách bao cấp sau nhiều năm áp dụng. Năm 1986, Đảng Cộng sản ban hành cải cách đổi mới, tạo điều kiện hình thành kinh tế thị trường và hội nhập sâu rộng. Cải cách đổi mới kết hợp cùng quy mô dân số lớn đưa Việt Nam trở thành một trong những nước đang phát triển có tốc độ tăng trưởng thuộc nhóm nhanh nhất thế giới, được coi là Hổ mới châu Á dù cho vẫn gặp phải những thách thức như tham nhũng, tội phạm gia tăng, ô nhiễm môi trường và phúc lợi xã hội chưa đầy đủ. Ngoài ra, giới bất đồng chính kiến, chính phủ một số nước phương Tây và các tổ chức theo dõi nhân quyền có quan điểm chỉ trích hồ sơ nhân quyền của Việt Nam liên quan đến các vấn đề tôn giáo, kiểm duyệt truyền thông, hạn chế hoạt động ủng hộ nhân quyền cùng các quyền tự do dân sự.",
            help="Enter context or reference text"
        )
        verify_button = st.button("🔍 Verify", use_container_width=True)
    with col2:
        st.markdown("### 📊 Results")

        if verify_button:
            with st.spinner("Verifying..."):
                # Preprocess texts
                preprocessed_claim = preprocess_text(claim)
                preprocessed_context = preprocess_text(context)

                # Clear memory and perform verification
                gc.collect()
                if DEVICE == "cuda":
                    torch.cuda.empty_cache()

                start_time = time.time()
                # Monitor initial performance
                initial_perf = monitor_performance()

                result = perform_verification(
                    preprocessed_claim, preprocessed_context,
                    model_qatc, tokenizer_qatc,
                    model_tc, tokenizer_tc,
                    model_bc, tokenizer_bc,
                    tfidf_threshold, length_ratio_threshold
                )
                total_time = time.time() - start_time

                # Monitor final performance
                final_perf = monitor_performance()

                # Format details
                details = ""
                if show_details:
                    gpu_memory_used = f"{float(final_perf.get('gpu_memory_used', 0)):.2f} MB" if DEVICE == "cuda" else "N/A"
                    gpu_memory_cached = f"{float(final_perf.get('gpu_memory_cached', 0)):.2f} MB" if DEVICE == "cuda" else "N/A"
                    details = f"""
Details:
- 3-Class Probability: {result['prob3class']:.2f}
- 3-Class Predicted Label: {['NEI', 'SUPPORTED', 'REFUTED'][result['pred_tc']]}
- 2-Class Probability: {result['prob2class']:.2f}
- 2-Class Predicted Label: {['SUPPORTED', 'REFUTED'][result['pred_bc']] if isinstance(result['pred_bc'], int) and result['pred_tc'] != 0 else 'Not used'}

Performance Metrics:
- GPU Memory Used: {gpu_memory_used}
- GPU Memory Cached: {gpu_memory_cached}
- CPU Usage: {final_perf['cpu_percent']}%
- Memory Usage: {final_perf['memory_percent']}%
"""
                # Store result with performance metrics
                st.session_state.latest_result = {
                    "claim": claim,
                    "evidence": result['evidence'],
                    "verdict": result['verdict'],
                    "evidence_time": result['evidence_time'],
                    "verdict_time": result['verdict_time'],
                    "total_time": total_time,
                    "details": details,
                    "qatc_model": qatc_model_name,
                    "bc_model": bc_model_name,
                    "tc_model": tc_model_name,
                    "performance_metrics": final_perf
                }

                # Add to history
                if 'history' not in st.session_state:
                    st.session_state.history = []
                st.session_state.history.append(st.session_state.latest_result)

            # Display result with performance metrics
            res = st.session_state.latest_result
            verdict_class = {
                "SUPPORTED": "verdict-supported",
                "REFUTED": "verdict-refuted",
                "NEI": "verdict-nei"
            }.get(res['verdict'], "")

            gpu_memory_text = (
                f"<li>GPU Memory: {float(res['performance_metrics'].get('gpu_memory_used', 0)):.2f} MB</li>"
                if DEVICE == "cuda"
                else "<li>GPU Memory: N/A</li>"
            )

            st.markdown(f"""
            <div class="result-box">
                <h3>Verification Results</h3>
                <p><strong>Claim:</strong> {res['claim']}</p>
                <p><strong>Evidence:</strong> {res['evidence']}</p>
                <p class="verdict {verdict_class}">
                    {verdict_icons.get(res['verdict'], '')} {res['verdict']}
                </p>
                <div class="stats-box">
                    <p><strong>Evidence Extraction Time:</strong> {res['evidence_time']:.2f} seconds</p>
                    <p><strong>Classification Time:</strong> {res['verdict_time']:.2f} seconds</p>
                    <p><strong>Total Time:</strong> {res['total_time']:.2f} seconds</p>
                    <p><strong>Performance:</strong></p>
                    <ul>
                        <li>CPU: {res['performance_metrics']['cpu_percent']}%</li>
                        <li>RAM: {res['performance_metrics']['memory_percent']}%</li>
                        {gpu_memory_text}
                    </ul>
                </div>
                {f"<div class='code-block'><pre>{res['details']}</pre></div>" if show_details else ""}
            </div>
            """, unsafe_allow_html=True)

            # Download button with performance metrics
            result_text = f"""
Claim: {res['claim']}
Evidence: {res['evidence']}
Verdict: {res['verdict']}
Details: {res['details']}

Performance:
- CPU: {res['performance_metrics']['cpu_percent']}%
- RAM: {res['performance_metrics']['memory_percent']}%
- GPU Memory: {f"{float(res['performance_metrics'].get('gpu_memory_used', 0)):.2f} MB" if DEVICE == "cuda" else "N/A"}
"""
            st.download_button(
                "📥 Download Results",
                data=result_text,
                file_name="verification_results.txt",
                mime="text/plain"
            )
        else:
            st.info("Please enter information and click Verify to begin.")
# --- Tab History ---
with tabs[1]:
    st.markdown("### 📜 Verification History")

    if 'history' in st.session_state and st.session_state.history:
        # Download full history
        history_df = pd.DataFrame(st.session_state.history)
        st.download_button(
            "📥 Download Full History",
            data=history_df.to_csv(index=False).encode('utf-8'),
            file_name="verification_history.csv",
            mime="text/csv"
        )

        # Display history, most recent first
        for idx, record in enumerate(reversed(st.session_state.history), 1):
            st.markdown(f"""
            <div class="result-box">
                <h4>Verification #{idx}</h4>
                <p><strong>Claim:</strong> {record['claim']}</p>
                <p><strong>Verdict:</strong> {verdict_icons.get(record['verdict'], '')} {record['verdict']}</p>
                <p><strong>Time:</strong> {record['total_time']:.2f} seconds</p>
            </div>
            """, unsafe_allow_html=True)
    else:
        st.info("No verification history available.")
# --- Tab Info ---
with tabs[2]:
    st.markdown("""
### ℹ️ About SemViQA

**Authors:** [**Nam V. Nguyen**](https://github.com/DAVID-NGUYEN-S16), [**Dien X. Tran**](https://github.com/xndien2004), Thanh T. Tran, Anh T. Hoang, Tai V. Duong, Di T. Le, Phuc-Lu Le

SemViQA is a Vietnamese fact-checking system designed to combat misinformation. It combines semantic-based evidence retrieval (SER) with a two-step verdict classification (TVC) approach: TF-IDF retrieval is paired with a Question Answering Token Classifier (QATC) to locate evidence, improving accuracy while reducing inference time. SemViQA has set new state-of-the-art results on the ViWikiFC (80.82% accuracy) and ISE-DSC01 (78.97% accuracy) benchmarks, and its roughly 7x inference speed-up makes it a practical tool for protecting information integrity in Vietnamese.

#### 📖 How to Use
1. Enter the claim to verify
2. Enter context or reference text
3. Adjust parameters in Settings if needed
4. Click the Verify button

#### ⚙️ Parameters
- **Confidence Threshold:** adjusts the sensitivity of the evidence search
- **Length Ratio Threshold:** governs how candidate text segments are filtered by their length relative to the claim during evidence retrieval
- **CPU Threads:** number of threads used for processing

#### 📊 Results
- **SUPPORTED:** The claim is supported by the evidence
- **REFUTED:** The claim is refuted by the evidence
- **NEI:** Not enough information to conclude
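
#### 🧪 Programmatic Use

A minimal sketch of driving the same SemViQA components from Python, using the call signatures this demo itself uses; the model names and thresholds below are illustrative choices, not a fixed API reference.

```python
from transformers import AutoTokenizer
from semviqa.ser.qatc_model import QATCForQuestionAnswering
from semviqa.ser.ser_eval import extract_evidence_tfidf_qatc
from semviqa.tvc.model import ClaimModelForClassification
from semviqa.tvc.tvc_eval import classify_claim

device = "cpu"  # or "cuda"
claim = "Chiến tranh với Campuchia đã kết thúc trước khi Việt Nam thống nhất."
context = "Sau khi thống nhất, Việt Nam tiếp tục gặp khó khăn..."

# Evidence retrieval (SER)
tok_q = AutoTokenizer.from_pretrained("SemViQA/qatc-infoxlm-viwikifc")
model_q = QATCForQuestionAnswering.from_pretrained("SemViQA/qatc-infoxlm-viwikifc").to(device).eval()
evidence = extract_evidence_tfidf_qatc(
    claim, context, model_q, tok_q, device,
    confidence_threshold=0.5, length_ratio_threshold=0.5,
)

# Verdict classification (TVC), three-class step: 0=NEI, 1=SUPPORTED, 2=REFUTED
tok_tc = AutoTokenizer.from_pretrained("SemViQA/tc-infoxlm-viwikifc")
model_tc = ClaimModelForClassification.from_pretrained("SemViQA/tc-infoxlm-viwikifc", num_labels=3).to(device).eval()
prob, pred = classify_claim(claim, evidence, model_tc, tok_tc, device)
print(evidence, ["NEI", "SUPPORTED", "REFUTED"][int(pred)])
```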
""") |