# semviqa-demo / app.py
# xuandin's picture
# Update app.py
# ed692f5 verified
import gc
import html
import os
import threading
import time
from concurrent.futures import ThreadPoolExecutor
from functools import lru_cache

import numpy as np
import pandas as pd
import psutil
import streamlit as st
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer

from semviqa.ser.qatc_model import QATCForQuestionAnswering
from semviqa.ser.ser_eval import extract_evidence_tfidf_qatc
from semviqa.tvc.model import ClaimModelForClassification
from semviqa.tvc.tvc_eval import classify_claim
# Set environment variables to optimize CPU performance.
# psutil.cpu_count(logical=False) returns None when the physical core count
# cannot be determined; fall back to the logical count and finally to 1 so the
# environment variables never become the string "None" and
# torch.set_num_threads() always receives an int.
_DEFAULT_NUM_THREADS = psutil.cpu_count(logical=False) or psutil.cpu_count() or 1
os.environ["OMP_NUM_THREADS"] = str(_DEFAULT_NUM_THREADS)
os.environ["MKL_NUM_THREADS"] = str(_DEFAULT_NUM_THREADS)
torch.set_num_threads(_DEFAULT_NUM_THREADS)
# Set device globally
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Memoized claim classification (at most 1000 distinct argument combinations)
@lru_cache(maxsize=1000)
def cached_classify_claim(claim, evidence, model_name, is_bc=False):
    """Classify ``claim`` against ``evidence`` and memoize the outcome.

    The tokenizer/model pair is obtained through ``load_model`` for
    ``model_name``; ``is_bc`` is forwarded so the loader requests the
    two-label configuration instead of the three-label one. Inference runs
    under ``torch.no_grad()`` on the module-level ``DEVICE``.

    Returns:
        tuple: ``(prob, pred)`` as unpacked from ``classify_claim``.
    """
    tok, classifier = load_model(
        model_name,
        ClaimModelForClassification,
        is_bc=is_bc,
        device=DEVICE,
    )
    with torch.no_grad():
        probability, label = classify_claim(claim, evidence, classifier, tok, DEVICE)
    return probability, label
# Model/tokenizer loading, cached as a Streamlit resource
@st.cache_resource(ttl=3600)  # cached entry expires after 3600 s (1 hour)
def load_model(model_name, model_class, is_bc=False, device=None):
    """Load a tokenizer and a model prepared for inference.

    Args:
        model_name: Identifier passed to ``from_pretrained`` for both the
            tokenizer and the model.
        model_class: Class exposing ``from_pretrained``; it is called with
            ``num_labels`` set to 2 when ``is_bc`` is true, otherwise 3.
        is_bc: Whether to request the two-label configuration.
        device: Target device string. When ``None`` it resolves to ``"cuda"``
            if CUDA is available, else ``"cpu"``.

    Returns:
        tuple: ``(tokenizer, model)`` after ``model.eval()`` and
        ``model.to(device)`` have been called.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    label_count = 2 if is_bc else 3
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = model_class.from_pretrained(model_name, num_labels=label_count)
    model.eval()

    running_on_gpu = device == "cuda"
    if running_on_gpu:
        # FP16 weights for faster GPU inference, then release unused cached memory
        model = model.half()
        torch.cuda.empty_cache()

    model.to(device)
    return tokenizer, model
# Text normalization applied to user input (results cached for 3600 s)
@st.cache_data(ttl=3600)
def preprocess_text(text):
    """Return ``text`` with leading and trailing whitespace removed.

    This is the single place to add further cleaning or normalization.
    """
    cleaned = text.strip()
    return cleaned
# Batch processing for evidence extraction
def batch_extract_evidence(claims, contexts, model_qatc, tokenizer_qatc, batch_size=4,
                           confidence_threshold=0.5, length_ratio_threshold=0.5):
    """Extract evidence for several claim/context pairs, chunk by chunk.

    Each pair is still handed to ``extract_evidence_tfidf_qatc`` one at a
    time; ``batch_size`` only controls how the inputs are sliced.

    Args:
        claims: Sequence of claims.
        contexts: Sequence of contexts, paired with ``claims`` by position.
            Within a chunk, pairing stops at the shorter of the two slices.
        model_qatc: Model forwarded to ``extract_evidence_tfidf_qatc``.
        tokenizer_qatc: Tokenizer forwarded to ``extract_evidence_tfidf_qatc``.
        batch_size: Positive number of pairs per chunk.
        confidence_threshold: Forwarded to the extractor (default 0.5, the
            value that used to be hard-coded).
        length_ratio_threshold: Forwarded to the extractor (default 0.5, the
            value that used to be hard-coded).

    Returns:
        list: One extractor result per processed pair, in input order.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    # A negative step would make range() empty and silently return no results.
    if batch_size <= 0:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")

    results = []
    # Gradients are never needed for inference; enter no_grad once for all chunks.
    with torch.no_grad():
        for start in range(0, len(claims), batch_size):
            stop = start + batch_size
            results.extend(
                extract_evidence_tfidf_qatc(
                    claim, context, model_qatc, tokenizer_qatc,
                    DEVICE,
                    confidence_threshold=confidence_threshold,
                    length_ratio_threshold=length_ratio_threshold,
                )
                for claim, context in zip(claims[start:stop], contexts[start:stop])
            )
    return results
# Verification pipeline: evidence extraction, then both classifiers in parallel
def perform_verification(claim, context, model_qatc, tokenizer_qatc, model_tc, tokenizer_tc,
                         model_bc, tokenizer_bc, tfidf_threshold, length_ratio_threshold):
    """Extract evidence for ``claim`` from ``context`` and derive a verdict.

    Args:
        claim: Claim text.
        context: Reference text to search for evidence.
        model_qatc, tokenizer_qatc: Forwarded to ``extract_evidence_tfidf_qatc``.
        model_tc, tokenizer_tc, model_bc, tokenizer_bc: Accepted but not used
            by this body.
        tfidf_threshold: Passed to the extractor as ``confidence_threshold``.
        length_ratio_threshold: Passed to the extractor unchanged.

    Returns:
        dict: ``evidence``, ``verdict`` (``"NEI"``, ``"SUPPORTED"`` or
        ``"REFUTED"``), ``evidence_time`` and ``verdict_time`` in seconds,
        and the raw ``prob3class``/``pred_tc``/``prob2class``/``pred_bc``
        values (probabilities converted with ``.item()`` when tensors).

    NOTE(review): the classifiers are resolved inside ``cached_classify_claim``
    from the module-level names ``tc_model_name`` and ``bc_model_name``, which
    must exist before this function is called — confirm that the unused model
    parameters are intentional.
    """
    labels_3class = ["NEI", "SUPPORTED", "REFUTED"]

    def _scalar(value):
        # Turn a tensor into a plain Python number; leave anything else as is.
        return value.item() if isinstance(value, torch.Tensor) else value

    # Stage 1: evidence extraction
    extraction_started = time.time()
    evidence = extract_evidence_tfidf_qatc(
        claim, context, model_qatc, tokenizer_qatc,
        DEVICE,
        confidence_threshold=tfidf_threshold,
        length_ratio_threshold=length_ratio_threshold,
    )
    evidence_time = time.time() - extraction_started

    # Free memory before the classifiers run
    if DEVICE == "cuda":
        torch.cuda.empty_cache()
    gc.collect()

    # Stage 2: three-class and binary classification on two worker threads
    classification_started = time.time()
    with ThreadPoolExecutor(max_workers=2) as pool:
        three_class_job = pool.submit(cached_classify_claim, claim, evidence, tc_model_name, False)
        binary_job = pool.submit(cached_classify_claim, claim, evidence, bc_model_name, True)
        prob3class, pred_tc = three_class_job.result()
        prob2class, pred_bc = binary_job.result()

    with torch.no_grad():
        # Three-class label 0 means NEI and is final. Otherwise a binary
        # prediction of 0 yields SUPPORTED; failing that, REFUTED when the
        # binary probability exceeds the three-class one, else the
        # three-class label.
        if pred_tc != 0:
            if pred_bc == 0:
                verdict = "SUPPORTED"
            elif prob2class > prob3class:
                verdict = "REFUTED"
            else:
                verdict = labels_3class[pred_tc]
        else:
            verdict = "NEI"
    verdict_time = time.time() - classification_started

    return {
        "evidence": evidence,
        "verdict": verdict,
        "evidence_time": evidence_time,
        "verdict_time": verdict_time,
        "prob3class": _scalar(prob3class),
        "pred_tc": pred_tc,
        "prob2class": _scalar(prob2class),
        "pred_bc": pred_bc,
    }
# Resource usage snapshot
def monitor_performance():
    """Collect current CPU, RAM and (on CUDA) GPU memory figures.

    Returns:
        dict: Always contains ``cpu_percent`` and ``memory_percent`` read from
        psutil. When ``DEVICE`` is ``"cuda"`` the dict additionally starts
        with ``gpu_memory_used`` and ``gpu_memory_cached``, i.e.
        ``torch.cuda.memory_allocated()`` and ``torch.cuda.memory_reserved()``
        divided by 1024**2.
    """
    mib = 1024**2
    stats = {}
    if DEVICE == "cuda":
        # GPU entries go first so the key order matches the CUDA result layout
        stats["gpu_memory_used"] = torch.cuda.memory_allocated() / mib
        stats["gpu_memory_cached"] = torch.cuda.memory_reserved() / mib
    stats["cpu_percent"] = psutil.cpu_percent()
    stats["memory_percent"] = psutil.virtual_memory().percent
    return stats
# Set page configuration (browser tab title/icon, wide layout, sidebar open).
# The icon literal was mis-encoded (UTF-8 bytes shown through a single-byte
# code page); it is restored to the magnifying-glass emoji.
st.set_page_config(
    page_title="SemViQA - A Semantic Question Answering System for Vietnamese Information Fact-Checking",
    page_icon="🔍",
    layout="wide",
    initial_sidebar_state="expanded"
)
# Custom CSS
# Injects a <style> element through st.markdown; unsafe_allow_html=True is what
# lets the raw HTML through. The sheet declares the color variables under :root
# and the classes used by the HTML fragments rendered elsewhere in this file
# (.main-header, .main-title, .sub-title, .result-box, .verdict*, .stats-box,
# .code-block), plus overrides for Streamlit widgets (.stApp, .stTextArea,
# .stButton, .stTabs).
# NOTE(review): ".css-1d391kg" looks like an auto-generated Streamlit class
# name — confirm it still matches the sidebar in the deployed Streamlit version.
st.markdown("""
<style>
/* Main theme colors */
:root {
--primary-color: #1f77b4;
--secondary-color: #2c3e50;
--accent-color: #e74c3c;
--background-color: #f8f9fa;
--text-color: #2c3e50;
--border-color: #e0e0e0;
}
/* General styling */
.stApp {
background-color: var(--background-color);
color: var(--text-color);
}
/* Header styling */
.main-header {
background: linear-gradient(135deg, var(--primary-color), var(--secondary-color));
color: white;
padding: 2rem;
border-radius: 10px;
margin-bottom: 2rem;
box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
}
.main-title {
font-size: 2.5rem;
font-weight: bold;
text-align: center;
margin-bottom: 1rem;
}
.sub-title {
font-size: 1.2rem;
text-align: center;
opacity: 0.9;
}
/* Input styling */
.stTextArea textarea {
border: 2px solid var(--border-color);
border-radius: 8px;
padding: 1rem;
font-size: 1rem;
min-height: 150px;
background-color: white;
}
/* Button styling */
.stButton>button {
background: linear-gradient(135deg, var(--primary-color), var(--secondary-color));
color: white;
border: none;
border-radius: 8px;
padding: 0.8rem 2rem;
font-size: 1.1rem;
font-weight: bold;
transition: all 0.3s ease;
}
.stButton>button:hover {
transform: translateY(-2px);
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.2);
}
/* Result box styling */
.result-box {
background-color: white;
border-radius: 12px;
padding: 2rem;
margin: 1rem 0;
box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
}
/* Info section styling */
.info-section {
background-color: white;
border-radius: 12px;
padding: 2rem;
margin: 1rem 0;
box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
}
.info-section h3 {
color: var(--primary-color);
font-size: 1.8rem;
margin-bottom: 1.5rem;
border-bottom: 2px solid var(--border-color);
padding-bottom: 0.5rem;
}
.info-section h4 {
color: var(--secondary-color);
font-size: 1.4rem;
margin: 1.5rem 0 1rem 0;
}
.info-section p {
font-size: 1.1rem;
line-height: 1.6;
color: var(--text-color);
margin-bottom: 1.5rem;
}
.info-section ol, .info-section ul {
margin-left: 1.5rem;
margin-bottom: 1.5rem;
}
.info-section li {
font-size: 1.1rem;
line-height: 1.6;
margin-bottom: 0.5rem;
}
.info-section strong {
color: var(--primary-color);
}
.verdict {
font-size: 1.8rem;
font-weight: bold;
padding: 1rem;
border-radius: 8px;
margin: 1rem 0;
text-align: center;
}
.verdict-supported {
background-color: #d4edda;
color: #155724;
}
.verdict-refuted {
background-color: #f8d7da;
color: #721c24;
}
.verdict-nei {
background-color: #fff3cd;
color: #856404;
}
/* Sidebar styling */
.css-1d391kg {
background-color: white;
padding: 2rem;
border-radius: 12px;
box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
}
/* Stats box styling */
.stats-box {
background-color: white;
border-radius: 8px;
padding: 1rem;
margin: 0.5rem 0;
box-shadow: 0 2px 4px rgba(0, 0, 0, 0.05);
}
/* Code block styling */
.code-block {
background-color: #f8f9fa;
border: 1px solid var(--border-color);
border-radius: 8px;
padding: 1rem;
font-family: monospace;
margin: 1rem 0;
}
/* Tab styling */
.stTabs [data-baseweb="tab-list"] {
gap: 2rem;
}
.stTabs [data-baseweb="tab"] {
background-color: white;
border-radius: 8px;
padding: 0.5rem 1rem;
margin: 0 0.5rem;
}
.stTabs [aria-selected="true"] {
background-color: var(--primary-color);
color: white;
}
</style>
""", unsafe_allow_html=True)
# Page banner: title and subtitle wrapped in the header container markup
_main_header_html = """
<div class="main-header">
<div class="main-title">SemViQA</div>
<div class="sub-title">A Semantic Question Answering System for Vietnamese Information Fact-Checking</div>
</div>
"""
st.markdown(_main_header_html, unsafe_allow_html=True)
# Sidebar: model choices, thresholds, display toggle and CPU thread count.
# The heading emoji were mis-encoded and are restored here.
with st.sidebar:
    st.markdown("### ⚙️ System Settings")
    # Model selection
    st.markdown("#### 🧠 Model Selection")
    qatc_model_name = st.selectbox(
        "QATC Model",
        [
            "SemViQA/qatc-infoxlm-viwikifc",
            "SemViQA/qatc-infoxlm-isedsc01",
            "SemViQA/qatc-vimrc-viwikifc",
            "SemViQA/qatc-vimrc-isedsc01"
        ]
    )
    bc_model_name = st.selectbox(
        "Binary Classification Model",
        [
            "SemViQA/bc-xlmr-viwikifc",
            "SemViQA/bc-xlmr-isedsc01",
            "SemViQA/bc-infoxlm-viwikifc",
            "SemViQA/bc-infoxlm-isedsc01",
            "SemViQA/bc-erniem-viwikifc",
            "SemViQA/bc-erniem-isedsc01"
        ]
    )
    tc_model_name = st.selectbox(
        "Three-Class Classification Model",
        [
            "SemViQA/tc-xlmr-viwikifc",
            "SemViQA/tc-xlmr-isedsc01",
            "SemViQA/tc-infoxlm-viwikifc",
            "SemViQA/tc-infoxlm-isedsc01",
            "SemViQA/tc-erniem-viwikifc",
            "SemViQA/tc-erniem-isedsc01"
        ]
    )
    # Threshold settings
    st.markdown("#### ⚖️ Analysis Thresholds")
    tfidf_threshold = st.slider(
        "Confidence Threshold",
        0.0, 1.0, 0.5,
        help="Adjust sensitivity in evidence search"
    )
    length_ratio_threshold = st.slider(
        "Length Ratio Threshold",
        0.1, 1.0, 0.5,
        help="Adjust maximum evidence length"
    )
    # Display settings
    st.markdown("#### 👁️ Display")
    show_details = st.checkbox(
        "Show Probability Details",
        value=False,
        help="Display detailed probability information"
    )
    # Performance settings
    st.markdown("#### ⚡ Performance")
    num_threads = st.slider(
        "CPU Threads",
        1, psutil.cpu_count(),
        psutil.cpu_count(logical=False),
        help="Adjust processing performance"
    )
    os.environ["OMP_NUM_THREADS"] = str(num_threads)
    os.environ["MKL_NUM_THREADS"] = str(num_threads)
    # The environment variables alone do not resize the thread pool of the
    # already-imported torch; apply the chosen value explicitly so the slider
    # actually takes effect.
    torch.set_num_threads(num_threads)
# Main content: the three tabs, the models selected in the sidebar, and the
# icon shown next to each verdict. Tab labels and icons were mis-encoded and
# are restored to the intended emoji.
tabs = st.tabs(["🔍 Verify", "📊 History", "ℹ️ Info"])
tokenizer_qatc, model_qatc = load_model(qatc_model_name, QATCForQuestionAnswering, device=DEVICE)
tokenizer_bc, model_bc = load_model(bc_model_name, ClaimModelForClassification, is_bc=True, device=DEVICE)
tokenizer_tc, model_tc = load_model(tc_model_name, ClaimModelForClassification, device=DEVICE)
verdict_icons = {
    "SUPPORTED": "✅",
    "REFUTED": "❌",
    "NEI": "⚠️"
}
# --- Tab Verify ---
# Left column collects the claim and context; right column runs the pipeline
# when the button is pressed and renders the outcome. User-controlled text is
# HTML-escaped before being placed into unsafe_allow_html markup. The emoji
# and the Vietnamese sample texts were mis-encoded and are restored here.
with tabs[0]:
    col1, col2 = st.columns([2, 1])
    with col1:
        st.markdown("### 📝 Input Information")
        claim = st.text_area(
            "Claim to Verify",
            "Chiến tranh với Campuchia đã kết thúc trước khi Việt Nam thống nhất.",
            help="Enter the claim to be verified"
        )
        context = st.text_area(
            "Context",
            "Sau khi thống nhất, Việt Nam tiếp tục gặp khó khăn do sự sụp đổ và tan rã của đồng minh Liên Xô cùng Khối phía Đông, các lệnh cấm vận của Hoa Kỳ, chiến tranh với Campuchia, biên giới giáp Trung Quốc và hậu quả của chính sách bao cấp sau nhiều năm áp dụng. Năm 1986, Đảng Cộng sản ban hành cải cách đổi mới, tạo điều kiện hình thành kinh tế thị trường và hội nhập sâu rộng. Cải cách đổi mới kết hợp cùng quy mô dân số lớn đưa Việt Nam trở thành một trong những nước đang phát triển có tốc độ tăng trưởng thuộc nhóm nhanh nhất thế giới, được coi là Hổ mới châu Á dù cho vẫn gặp phải những thách thức như tham nhũng, tội phạm gia tăng, ô nhiễm môi trường và phúc lợi xã hội chưa đầy đủ. Ngoài ra, giới bất đồng chính kiến, chính phủ một số nước phương Tây và các tổ chức theo dõi nhân quyền có quan điểm chỉ trích hồ sơ nhân quyền của Việt Nam liên quan đến các vấn đề tôn giáo, kiểm duyệt truyền thông, hạn chế hoạt động ủng hộ nhân quyền cùng các quyền tự do dân sự.",
            help="Enter context or reference text"
        )
        verify_button = st.button("🔍 Verify", use_container_width=True)
    with col2:
        st.markdown("### 📊 Results")
        if verify_button:
            with st.spinner("Verifying..."):
                # Preprocess texts
                preprocessed_claim = preprocess_text(claim)
                preprocessed_context = preprocess_text(context)
                # Clear memory and perform verification
                gc.collect()
                if DEVICE == "cuda":
                    torch.cuda.empty_cache()
                start_time = time.time()
                # Monitor initial performance
                initial_perf = monitor_performance()
                result = perform_verification(
                    preprocessed_claim, preprocessed_context,
                    model_qatc, tokenizer_qatc,
                    model_tc, tokenizer_tc,
                    model_bc, tokenizer_bc,
                    tfidf_threshold, length_ratio_threshold
                )
                total_time = time.time() - start_time
                # Monitor final performance
                final_perf = monitor_performance()
                # Format details (plain text; escaped later when embedded in HTML)
                details = ""
                if show_details:
                    gpu_memory_used = f"{float(final_perf.get('gpu_memory_used', 0)):.2f} MB" if DEVICE == "cuda" else "N/A"
                    gpu_memory_cached = f"{float(final_perf.get('gpu_memory_cached', 0)):.2f} MB" if DEVICE == "cuda" else "N/A"
                    three_class_label = ['NEI', 'SUPPORTED', 'REFUTED'][result['pred_tc']]
                    # The binary label is reported only when the three-class
                    # prediction is not NEI (label 0).
                    two_class_label = (
                        ['SUPPORTED', 'REFUTED'][result['pred_bc']]
                        if isinstance(result['pred_bc'], int) and result['pred_tc'] != 0
                        else 'Not used'
                    )
                    details = f"""
Details:
- 3-Class Probability: {result['prob3class']:.2f}
- 3-Class Predicted Label: {three_class_label}
- 2-Class Probability: {result['prob2class']:.2f}
- 2-Class Predicted Label: {two_class_label}
Performance Metrics:
- GPU Memory Used: {gpu_memory_used}
- GPU Memory Cached: {gpu_memory_cached}
- CPU Usage: {final_perf['cpu_percent']}%
- Memory Usage: {final_perf['memory_percent']}%
"""
                # Store result with performance metrics
                st.session_state.latest_result = {
                    "claim": claim,
                    "evidence": result['evidence'],
                    "verdict": result['verdict'],
                    "evidence_time": result['evidence_time'],
                    "verdict_time": result['verdict_time'],
                    "total_time": total_time,
                    "details": details,
                    "qatc_model": qatc_model_name,
                    "bc_model": bc_model_name,
                    "tc_model": tc_model_name,
                    "performance_metrics": final_perf
                }
                # Add to history
                if 'history' not in st.session_state:
                    st.session_state.history = []
                st.session_state.history.append(st.session_state.latest_result)
            # Display result with performance metrics
            res = st.session_state.latest_result
            verdict_class = {
                "SUPPORTED": "verdict-supported",
                "REFUTED": "verdict-refuted",
                "NEI": "verdict-nei"
            }.get(res['verdict'], "")
            gpu_memory_value = (
                f"{float(res['performance_metrics'].get('gpu_memory_used', 0)):.2f} MB"
                if DEVICE == "cuda"
                else "N/A"
            )
            gpu_memory_text = f"<li>GPU Memory: {gpu_memory_value}</li>"
            # Claim comes from the user and evidence is a span of user-provided
            # context: escape both (and the details text) so they cannot inject
            # markup into the unsafe_allow_html block below.
            safe_claim = html.escape(str(res['claim']))
            safe_evidence = html.escape(str(res['evidence']))
            details_html = (
                f"<div class='code-block'><pre>{html.escape(res['details'])}</pre></div>"
                if show_details
                else ""
            )
            st.markdown(f"""
<div class="result-box">
<h3>Verification Results</h3>
<p><strong>Claim:</strong> {safe_claim}</p>
<p><strong>Evidence:</strong> {safe_evidence}</p>
<p class="verdict {verdict_class}">
{verdict_icons.get(res['verdict'], '')} {res['verdict']}
</p>
<div class="stats-box">
<p><strong>Evidence Extraction Time:</strong> {res['evidence_time']:.2f} seconds</p>
<p><strong>Classification Time:</strong> {res['verdict_time']:.2f} seconds</p>
<p><strong>Total Time:</strong> {res['total_time']:.2f} seconds</p>
<p><strong>Performance:</strong></p>
<ul>
<li>CPU: {res['performance_metrics']['cpu_percent']}%</li>
<li>RAM: {res['performance_metrics']['memory_percent']}%</li>
{gpu_memory_text}
</ul>
</div>
{details_html}
</div>
""", unsafe_allow_html=True)
            # Download button with performance metrics (plain text, so unescaped)
            result_text = f"""
Claim: {res['claim']}
Evidence: {res['evidence']}
Verdict: {res['verdict']}
Details: {res['details']}
Performance:
- CPU: {res['performance_metrics']['cpu_percent']}%
- RAM: {res['performance_metrics']['memory_percent']}%
- GPU Memory: {gpu_memory_value}
"""
            st.download_button(
                "📥 Download Results",
                data=result_text,
                file_name="verification_results.txt",
                mime="text/plain"
            )
        else:
            st.info("Please enter information and click Verify to begin.")
# --- Tab History ---
# Lists every verification stored in st.session_state.history, newest first,
# and offers the whole history as a CSV download. The emoji were mis-encoded
# and are restored here.
with tabs[1]:
    st.markdown("### 📊 Verification History")
    if 'history' in st.session_state and st.session_state.history:
        # Download full history
        history_df = pd.DataFrame(st.session_state.history)
        st.download_button(
            "📥 Download Full History",
            data=history_df.to_csv(index=False).encode('utf-8'),
            file_name="verification_history.csv",
            mime="text/csv"
        )
        # Display history (index 1 is the most recent record)
        for idx, record in enumerate(reversed(st.session_state.history), 1):
            # The claim is user input: escape it before embedding it in raw HTML.
            safe_claim = html.escape(str(record['claim']))
            st.markdown(f"""
<div class="result-box">
<h4>Verification #{idx}</h4>
<p><strong>Claim:</strong> {safe_claim}</p>
<p><strong>Verdict:</strong> {verdict_icons.get(record['verdict'], '')} {record['verdict']}</p>
<p><strong>Time:</strong> {record['total_time']:.2f} seconds</p>
</div>
""", unsafe_allow_html=True)
    else:
        st.info("No verification history available.")
# --- Tab Info ---
# Static description of the system, usage steps, parameters and verdict labels.
# The heading emoji were mis-encoded and are restored here.
with tabs[2]:
    st.markdown("""
### ℹ️ About SemViQA
**Author:** [**Nam V. Nguyen**](https://github.com/DAVID-NGUYEN-S16), [**Dien X. Tran**](https://github.com/xndien2004), Thanh T. Tran, Anh T. Hoang, Tai V. Duong, Di T. Le, Phuc-Lu Le
SemViQA is a cutting-edge Vietnamese fact-checking system designed to combat misinformation. It leverages semantic-based evidence retrieval (SER) and a two-step verdict classification (TVC) approach to verify claims efficiently. By combining TF-IDF with a Question Answering Token Classifier (QATC), SemViQA improves accuracy while reducing inference time. Achieving state-of-the-art performance, it has set new benchmarks on ViWikiFC (80.82% accuracy) and ISE-DSC01 (78.97% accuracy) datasets. With its 7x speed boost, SemViQA is a powerful tool for ensuring information integrity in the Vietnamese language.
#### 🔍 How to Use
1. Enter the claim to verify
2. Enter context or reference text
3. Adjust parameters in Settings if needed
4. Click Verify button
#### ⚙️ Parameters
- **Confidence Threshold:** Adjust sensitivity in evidence search
- **Length Ratio Threshold:** An important parameter in the evidence retrieval process. It determines how text segments are processed when compared to the length of the claim to be verified.
- **CPU Threads:** Adjust processing performance
#### 📊 Results
- **SUPPORTED:** The claim is supported by evidence
- **REFUTED:** The claim is refuted by evidence
- **NEI:** Not enough information to conclude
""")