semviqa-demo / app.py
xuandin's picture
Update app.py
fb62d04 verified
raw
history blame
11.3 kB
import html
import time

import pandas as pd
import streamlit as st
import torch
from transformers import AutoTokenizer

from semviqa.ser.qatc_model import QATCForQuestionAnswering
from semviqa.ser.ser_eval import extract_evidence_tfidf_qatc
from semviqa.tvc.model import ClaimModelForClassification
from semviqa.tvc.tvc_eval import classify_claim
# Model loading is wrapped in st.cache_resource so a checkpoint is loaded once
# and reused on later reruns instead of being re-read every time.
@st.cache_resource()
def load_model(model_name, model_class, is_bc=False):
    """Load a tokenizer and model for ``model_name`` and switch the model to eval mode.

    Args:
        model_name: checkpoint identifier handed to ``from_pretrained``
            (the sidebar passes Hugging Face Hub ids such as "SemViQA/...").
        model_class: class whose ``from_pretrained`` builds the model.
        is_bc: True for the binary classifier (2 labels); otherwise 3 labels
            are requested.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    # NOTE(review): num_labels is forwarded to every model class, including the
    # QATC extractor — confirm that class accepts/ignores it.
    label_count = 2 if is_bc else 3
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = model_class.from_pretrained(model_name, num_labels=label_count)
    model.eval()
    return tokenizer, model
# Set up page configuration: browser tab title and a wide main layout.
st.set_page_config(page_title="SemViQA Demo", layout="wide")
# Custom CSS that pins the sidebar in place while the main content scrolls.
# JavaScript is deliberately not injected here: st.markdown renders HTML but
# does not execute <script> tags (st.components.v1.html is the API for running
# JS), and `position: fixed` already keeps the sidebar in view without any
# scroll handler.
st.markdown(
    """
<style>
/* Fix the sidebar */
.stSidebar {
position: fixed;
top: 0;
height: 100vh;
overflow-y: auto;
z-index: 1000;
}
/* Adjust main content to avoid overlap with the fixed sidebar */
.main .block-container {
margin-left: 25rem; /* Adjust this value based on your sidebar width */
}
</style>
""",
    unsafe_allow_html=True,
)
# Page header: title and subtitle rendered together inside one container.
# NOTE(review): the 'big-title' / 'sub-title' classes do not appear in the
# page's custom CSS — confirm whether styles for them are missing.
title_html = "<p class='big-title'>SemViQA: A Semantic Question Answering System for Vietnamese Information Fact-Checking</p>"
subtitle_html = "<p class='sub-title'>Enter the claim and context to verify its accuracy</p>"
with st.container():
    for header_fragment in (title_html, subtitle_html):
        st.markdown(header_fragment, unsafe_allow_html=True)
# Sidebar: global settings read by the Verify tab (thresholds, model choice,
# and whether to show per-model probabilities).
QATC_MODEL_OPTIONS = [
    "SemViQA/qatc-infoxlm-viwikifc",
    "SemViQA/qatc-infoxlm-isedsc01",
    "SemViQA/qatc-vimrc-viwikifc",
    "SemViQA/qatc-vimrc-isedsc01",
]
BC_MODEL_OPTIONS = [
    "SemViQA/bc-xlmr-viwikifc",
    "SemViQA/bc-xlmr-isedsc01",
    "SemViQA/bc-infoxlm-viwikifc",
    "SemViQA/bc-infoxlm-isedsc01",
    "SemViQA/bc-erniem-viwikifc",
    "SemViQA/bc-erniem-isedsc01",
]
TC_MODEL_OPTIONS = [
    "SemViQA/tc-xlmr-viwikifc",
    "SemViQA/tc-xlmr-isedsc01",
    "SemViQA/tc-infoxlm-viwikifc",
    "SemViQA/tc-infoxlm-isedsc01",
    "SemViQA/tc-erniem-viwikifc",
    "SemViQA/tc-erniem-isedsc01",
]

with st.sidebar.expander("⚙️ Settings", expanded=True):
    # NOTE(review): labelled "TF-IDF Threshold" but later passed to the evidence
    # extractor as confidence_threshold — confirm the label matches its meaning.
    tfidf_threshold = st.slider(
        "TF-IDF Threshold", min_value=0.0, max_value=1.0, value=0.5, step=0.01
    )
    length_ratio_threshold = st.slider(
        "Length Ratio Threshold", min_value=0.1, max_value=1.0, value=0.5, step=0.01
    )
    qatc_model_name = st.selectbox("QATC Model", QATC_MODEL_OPTIONS)
    bc_model_name = st.selectbox("Binary Classification Model", BC_MODEL_OPTIONS)
    tc_model_name = st.selectbox("3-Class Classification Model", TC_MODEL_OPTIONS)
    show_details = st.checkbox("Show Probability Details", value=False)
# Initialise per-session state on the first run only: `history` collects every
# verification result, `latest_result` holds the most recent one (None until
# the first verification).
_SESSION_DEFAULTS = {
    "history": list,
    "latest_result": lambda: None,
}
for _state_key, _make_default in _SESSION_DEFAULTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _make_default()
# Fetch the three models chosen in the sidebar (served from load_model's
# resource cache after the first load): the QATC evidence extractor, the binary
# classifier, and the 3-class classifier. Calls run in this order, left to right.
(tokenizer_qatc, model_qatc), (tokenizer_bc, model_bc), (tokenizer_tc, model_tc) = (
    load_model(qatc_model_name, QATCForQuestionAnswering),
    load_model(bc_model_name, ClaimModelForClassification, is_bc=True),
    load_model(tc_model_name, ClaimModelForClassification),
)
# Emoji prefix shown next to each verdict label (unknown verdicts get none).
verdict_icons = dict(
    SUPPORTED="✅",
    REFUTED="❌",
    NEI="⚠️",
)
# Top-level navigation; the tabs are addressed by position (0, 1, 2) below.
TAB_LABELS = ["Verify", "History", "About"]
tabs = st.tabs(TAB_LABELS)
# --- Tab Verify ---
# Label order used to decode the classifiers' integer predictions.
TC_LABELS = ["NEI", "SUPPORTED", "REFUTED"]  # 3-class model output index
BC_LABELS = ["SUPPORTED", "REFUTED"]  # binary model output index


def _html_text(value):
    """Return ``value`` as text that is safe to embed in an HTML snippet.

    The text is HTML-escaped so user-entered claims or extracted evidence cannot
    inject markup into the ``unsafe_allow_html`` result box. Newlines become
    ``<br>`` so a blank line inside the text cannot end the surrounding HTML
    block when Streamlit parses it as Markdown.
    """
    return html.escape(str(value)).replace("\n", "<br>")


with tabs[0]:
    st.subheader("Verify a Claim")
    # 2-column layout: input on the left, results on the right
    col_input, col_result = st.columns([2, 1])
    with col_input:
        claim = st.text_area("Enter Claim", "Vietnam is a country in Southeast Asia.")
        context = st.text_area(
            "Enter Context",
            "Vietnam is a country located in Southeast Asia, covering an area of over 331,000 km² with a population of more than 98 million people.",
        )
        verify_button = st.button("Verify", key="verify_button")
    with col_result:
        st.markdown("<h3>Verification Result</h3>", unsafe_allow_html=True)
        if verify_button:
            device = "cuda" if torch.cuda.is_available() else "cpu"
            # Show a spinner while the pipeline is running.
            with st.spinner("Verifying..."):
                with torch.no_grad():
                    # Step 1: extract evidence for the claim from the context.
                    evidence_start_time = time.perf_counter()
                    evidence = extract_evidence_tfidf_qatc(
                        claim, context, model_qatc, tokenizer_qatc, device,
                        confidence_threshold=tfidf_threshold,
                        length_ratio_threshold=length_ratio_threshold,
                    )
                    evidence_time = time.perf_counter() - evidence_start_time

                    # Step 2: classify the claim against the extracted evidence.
                    verdict_start_time = time.perf_counter()
                    verdict = "NEI"
                    prob3class, pred_tc = classify_claim(
                        claim, evidence, model_tc, tokenizer_tc, device
                    )
                    # The binary model only runs when the 3-class model did not
                    # predict NEI (index 0); None means "not computed".
                    prob2class = pred_bc = None
                    if pred_tc != 0:
                        prob2class, pred_bc = classify_claim(
                            claim, evidence, model_bc, tokenizer_bc, device
                        )
                        if pred_bc == 0:
                            verdict = "SUPPORTED"
                        elif prob2class > prob3class:
                            # Binary model predicted REFUTED and its score beats
                            # the 3-class score.
                            verdict = "REFUTED"
                        else:
                            # Otherwise defer to the 3-class prediction.
                            verdict = TC_LABELS[pred_tc]

                    details = ""
                    if show_details:
                        detail_parts = [
                            f"<p><strong>3-Class Probability:</strong> {prob3class.item():.2f}</p>",
                            f"<p><strong>3-Class Predicted Label:</strong> {TC_LABELS[pred_tc]}</p>",
                        ]
                        # 2-class figures exist only when the binary model ran;
                        # referencing them unconditionally raised NameError for NEI.
                        if pred_bc is not None:
                            detail_parts += [
                                f"<p><strong>2-Class Probability:</strong> {prob2class.item():.2f}</p>",
                                f"<p><strong>2-Class Predicted Label:</strong> {BC_LABELS[pred_bc]}</p>",
                            ]
                        details = "\n".join(detail_parts)
                    verdict_time = time.perf_counter() - verdict_start_time

                # Record the run in the session history and as the latest result.
                result = {
                    "claim": claim,
                    "evidence": evidence,
                    "verdict": verdict,
                    "evidence_time": evidence_time,
                    "verdict_time": verdict_time,
                    "details": details,
                }
                st.session_state.history.append(result)
                st.session_state.latest_result = result
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        # Display the most recent result from session state, so it stays
        # visible on reruns that were not triggered by the Verify button.
        res = st.session_state.latest_result
        if res is not None:
            icon = verdict_icons.get(res["verdict"], "")
            result_parts = [
                "<div class='result-box'>",
                f"<p><strong>Claim:</strong> {_html_text(res['claim'])}</p>",
                f"<p><strong>Evidence:</strong> {_html_text(res['evidence'])}</p>",
                f"<p><strong>Evidence Inference Time:</strong> {res['evidence_time']:.2f} seconds</p>",
                f"<p><strong>Verdict Inference Time:</strong> {res['verdict_time']:.2f} seconds</p>",
                f"<p class='verdict'><span class='verdict-icon'>{icon}</span>{res['verdict']}</p>",
                res["details"],
                "</div>",
            ]
            # Skip empty parts (e.g. no details) so no blank line splits the HTML
            # block; lines are not indented so Markdown does not treat them as code.
            st.markdown(
                "\n".join(part for part in result_parts if part),
                unsafe_allow_html=True,
            )
            # Download Verification Result Feature
            result_text = f"Claim: {res['claim']}\nEvidence: {res['evidence']}\nVerdict: {res['verdict']}\nDetails: {res['details']}"
            st.download_button("Download Result", data=result_text, file_name="verification_result.txt", mime="text/plain")
        else:
            st.info("No verification result yet.")
# --- Tab History ---
with tabs[1]:
    st.subheader("Verification History")
    past_runs = st.session_state.history
    if not past_runs:
        st.write("No verification history yet.")
    else:
        # Offer every stored run as a single CSV download.
        history_csv = pd.DataFrame(past_runs).to_csv(index=False).encode('utf-8')
        st.download_button(
            label="Download Full History",
            data=history_csv,
            file_name="verification_history.csv",
            mime="text/csv",
        )
        # List runs newest first, numbered from 1.
        for position, entry in enumerate(past_runs[::-1], start=1):
            icon = verdict_icons.get(entry['verdict'], '')
            st.markdown(f"**{position}. Claim:** {entry['claim']} \n**Result:** {icon} {entry['verdict']}")
# --- Tab About ---
with tabs[2]:
    st.subheader("About")
    # Project badges: paper, model hub, PyPI package, GitHub repository.
    # The arXiv badge label shows the same paper id as its link (2503.00955);
    # it previously displayed an unrelated id (2411.00918).
    st.markdown("""
<p align="center">
<a href="https://arxiv.org/abs/2503.00955">
<img src="https://img.shields.io/badge/arXiv-2503.00955-red?style=flat&label=arXiv">
</a>
<a href="https://huggingface.co/SemViQA">
<img src="https://img.shields.io/badge/Hugging%20Face-Model-yellow?style=flat">
</a>
<a href="https://pypi.org/project/SemViQA">
<img src="https://img.shields.io/pypi/v/SemViQA?color=blue&label=PyPI">
</a>
<a href="https://github.com/DAVID-NGUYEN-S16/SemViQA">
<img src="https://img.shields.io/github/stars/DAVID-NGUYEN-S16/SemViQA?style=social">
</a>
</p>
""", unsafe_allow_html=True)
    # Markdown lines stay at column 0: indenting them by 4+ spaces would make
    # Markdown render them as a code block.
    st.markdown("""
**Description:**
SemViQA is a Semantic QA system designed for fact verification in Vietnamese.
The system extracts evidence from the provided context and classifies claims as **SUPPORTED**, **REFUTED**, or **NEI** (Not Enough Information) using advanced models.
""")