# semviqa-demo / app.py
# Hugging Face Space by xuandin, commit 4fdfda4 ("Update app.py"), 12.8 kB.
import html
import time  # used to measure inference time

import streamlit as st
import torch
from transformers import AutoTokenizer

from semviqa.ser.qatc_model import QATCForQuestionAnswering
from semviqa.ser.ser_eval import extract_evidence_tfidf_qatc
from semviqa.tvc.model import ClaimModelForClassification
from semviqa.tvc.tvc_eval import classify_claim
# Load models with caching
@st.cache_resource()
def load_model(model_name, model_class, is_bc=False, device=None):
    """Load a tokenizer/model pair, cached by Streamlit across reruns.

    Args:
        model_name: Hugging Face Hub id passed to ``from_pretrained``.
        model_class: Model class exposing ``from_pretrained``.
        is_bc: When True the model is built with 2 labels (binary
            classification); otherwise with 3 labels.
        device: Device string for the weights. Defaults to ``"cuda"`` when
            CUDA is available, else ``"cpu"``.

    Returns:
        ``(tokenizer, model)`` with the model in eval mode on ``device``.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    num_labels = 2 if is_bc else 3
    model = model_class.from_pretrained(model_name, num_labels=num_labels)
    # The verify flow hands this same device string to the inference helpers,
    # so keep the weights on it to avoid a CPU-weights / CUDA-inputs mismatch.
    model.to(device)
    model.eval()
    return tokenizer, model
# Browser-tab title and wide layout for the whole app.
st.set_page_config(page_title="SemViQA Demo", layout="wide")

# Custom CSS injected once per run: sticky header and tabs, result-box and
# verdict formatting, and sidebar widget spacing.
_CUSTOM_CSS = """
<style>
html, body {
height: 100%;
margin: 0;
overflow: hidden;
}
.main-container {
height: calc(100vh - 55px); /* Browser height - 55px */
overflow-y: auto;
padding: 20px;
}
.big-title {
font-size: 36px;
font-weight: bold;
color: #4A90E2;
text-align: center;
margin-top: 20px;
position: sticky; /* Pin the header */
top: 0;
background-color: white; /* Ensure the header covers content when scrolling */
z-index: 100; /* Ensure it's above other content */
}
.sub-title {
font-size: 20px;
color: #666;
text-align: center;
margin-bottom: 20px;
}
.stButton>button {
background-color: #4CAF50;
color: white;
font-size: 16px;
width: 100%;
border-radius: 8px;
padding: 10px;
}
.stTextArea textarea {
font-size: 16px;
min-height: 120px;
}
.result-box {
background-color: #f9f9f9;
padding: 20px;
border-radius: 10px;
box-shadow: 0px 4px 8px rgba(0, 0, 0, 0.1);
margin-top: 20px;
}
.verdict {
font-size: 24px;
font-weight: bold;
margin: 0;
display: flex;
align-items: center;
}
.verdict-icon {
margin-right: 10px;
}
/* Fix the tabs at the top */
div[data-baseweb="tab-list"] {
position: sticky;
top: 55px; /* Height of the header */
background-color: white;
z-index: 99;
}
.stSidebar .sidebar-content {
background-color: #f0f2f6;
padding: 20px;
border-radius: 10px;
}
.stSidebar .st-expander {
background-color: #ffffff;
border-radius: 10px;
padding: 10px;
margin-bottom: 10px;
}
.stSidebar .stSlider {
margin-bottom: 20px;
}
.stSidebar .stSelectbox {
margin-bottom: 20px;
}
.stSidebar .stCheckbox {
margin-bottom: 20px;
}
</style>
"""
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
# Page header: title and subtitle, styled by the .big-title / .sub-title CSS.
# NOTE(review): the .main-container CSS rule is not attached to this container.
_HEADER_HTML = (
    "<p class='big-title'>SemViQA: Vietnamese Semantic QA for Fact Verification</p>",
    "<p class='sub-title'>Enter the claim and context to verify its accuracy</p>",
)
with st.container():
    for _header_line in _HEADER_HTML:
        st.markdown(_header_line, unsafe_allow_html=True)
# Sidebar: pipeline thresholds, model checkpoints and display options.
QATC_MODEL_OPTIONS = [
    "SemViQA/qatc-infoxlm-viwikifc",
    "SemViQA/qatc-infoxlm-isedsc01",
    "SemViQA/qatc-vimrc-viwikifc",
    "SemViQA/qatc-vimrc-isedsc01",
]
BC_MODEL_OPTIONS = [
    "SemViQA/bc-xlmr-viwikifc",
    "SemViQA/bc-xlmr-isedsc01",
    "SemViQA/bc-infoxlm-viwikifc",
    "SemViQA/bc-infoxlm-isedsc01",
    "SemViQA/bc-erniem-viwikifc",
    "SemViQA/bc-erniem-isedsc01",
]
TC_MODEL_OPTIONS = [
    "SemViQA/tc-xlmr-viwikifc",
    "SemViQA/tc-xlmr-isedsc01",
    "SemViQA/tc-infoxlm-viwikifc",
    "SemViQA/tc-infoxlm-isedsc01",
    "SemViQA/tc-erniem-viwikifc",
    "SemViQA/tc-erniem-isedsc01",
]

with st.sidebar.expander("⚙️ Settings", expanded=True):
    # NOTE(review): this value is forwarded to the evidence extractor as
    # `confidence_threshold`; confirm the "TF-IDF" label describes it.
    tfidf_threshold = st.slider(
        "TF-IDF Threshold", min_value=0.0, max_value=1.0, value=0.5, step=0.01
    )
    length_ratio_threshold = st.slider(
        "Length Ratio Threshold", min_value=0.1, max_value=1.0, value=0.5, step=0.01
    )
    qatc_model_name = st.selectbox("QATC Model", options=QATC_MODEL_OPTIONS)
    bc_model_name = st.selectbox("Binary Classification Model", options=BC_MODEL_OPTIONS)
    tc_model_name = st.selectbox("3-Class Classification Model", options=TC_MODEL_OPTIONS)
    show_details = st.checkbox("Show Probability Details", value=False)
# Per-session storage: every past verification record, plus the latest one
# (None until the first verification). Only set keys that are still missing.
for _state_key, _initial_value in (("history", []), ("latest_result", None)):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _initial_value
# Fetch (from Streamlit's resource cache when possible) the three checkpoints
# chosen in the sidebar: QATC for evidence extraction, then the binary and
# 3-class claim classifiers.
(tokenizer_qatc, model_qatc), (tokenizer_bc, model_bc), (tokenizer_tc, model_tc) = (
    load_model(qatc_model_name, QATCForQuestionAnswering),
    load_model(bc_model_name, ClaimModelForClassification, is_bc=True),
    load_model(tc_model_name, ClaimModelForClassification),
)
# Emoji shown before each verdict label in the result box and history list.
verdict_icons = dict(
    SUPPORTED="✅",
    REFUTED="❌",
    NEI="⚠️",
)
# Top-level navigation; tabs[0] = Verify, tabs[1] = History, tabs[2] = About.
TAB_LABELS = ("Verify", "History", "About")
tabs = st.tabs(list(TAB_LABELS))
# --- Tab Verify ---
# Input form on the left; on the right, run the pipeline on click and always
# render the latest stored result.
with tabs[0]:
    st.subheader("Verify a Claim")
    # 2-column layout: input on the left, results on the right.
    col_input, col_result = st.columns([2, 1])

    with col_input:
        claim = st.text_area("Enter Claim", "Vietnam is a country in Southeast Asia.")
        context = st.text_area(
            "Enter Context",
            "Vietnam is a country located in Southeast Asia, covering an area of over 331,000 km² with a population of more than 98 million people.",
        )
        verify_button = st.button("Verify", key="verify_button")

    with col_result:
        st.markdown("<h3>Verification Result</h3>", unsafe_allow_html=True)

        if verify_button and not (claim.strip() and context.strip()):
            st.warning("Please enter both a claim and a context.")
        elif verify_button:
            device = "cuda" if torch.cuda.is_available() else "cpu"
            tc_labels = ["NEI", "SUPPORTED", "REFUTED"]  # 3-class label order
            bc_labels = ["SUPPORTED", "REFUTED"]  # binary label order

            # Interim box showing the evidence as soon as it is extracted. It is
            # cleared before the complete result box is rendered, so the
            # evidence does not appear twice.
            evidence_placeholder = st.empty()
            with st.spinner("Verifying..."):
                with torch.no_grad():
                    # Step 1: extract evidence from the context.
                    evidence_start = time.perf_counter()
                    evidence = extract_evidence_tfidf_qatc(
                        claim, context, model_qatc, tokenizer_qatc, device,
                        confidence_threshold=tfidf_threshold,
                        length_ratio_threshold=length_ratio_threshold,
                    )
                    evidence_time = time.perf_counter() - evidence_start
                    evidence_placeholder.markdown(
                        "<div class='result-box'>"
                        f"<p><strong>Evidence:</strong> {html.escape(str(evidence))}</p>"
                        f"<p><strong>Evidence Inference Time:</strong> {evidence_time:.2f} seconds</p>"
                        "</div>",
                        unsafe_allow_html=True,
                    )

                    # Step 2: classify the claim. The 3-class model runs first;
                    # the binary model is consulted only when it did not predict
                    # label 0 (NEI).
                    verdict_start = time.perf_counter()
                    verdict = "NEI"
                    prob2class = pred_bc = None  # stay None if the binary model is skipped
                    prob3class, pred_tc = classify_claim(
                        claim, evidence, model_tc, tokenizer_tc, device
                    )
                    if pred_tc != 0:
                        prob2class, pred_bc = classify_claim(
                            claim, evidence, model_bc, tokenizer_bc, device
                        )
                        if pred_bc == 0:
                            verdict = "SUPPORTED"
                        elif prob2class > prob3class:
                            verdict = "REFUTED"
                        else:
                            verdict = tc_labels[pred_tc]
                    verdict_time = time.perf_counter() - verdict_start

            evidence_placeholder.empty()

            # Optional probability details as (label, value) rows, so the same
            # data feeds both the HTML box and the plain-text download.
            detail_rows = []
            if show_details:
                detail_rows += [
                    ("3-Class Probability", f"{prob3class.item():.2f}"),
                    ("3-Class Predicted Label", tc_labels[pred_tc]),
                ]
                # Binary outputs exist only when the binary model actually ran.
                if pred_bc is not None:
                    detail_rows += [
                        ("2-Class Probability", f"{prob2class.item():.2f}"),
                        ("2-Class Predicted Label", bc_labels[pred_bc]),
                    ]
            details = "\n".join(
                f"<p><strong>{label}:</strong> {value}</p>" for label, value in detail_rows
            )
            details_text = "; ".join(f"{label}: {value}" for label, value in detail_rows)

            # Store verification history and the latest result.
            result = {
                "claim": claim,
                "evidence": evidence,
                "verdict": verdict,
                "evidence_time": evidence_time,
                "verdict_time": verdict_time,
                "details": details,
                "details_text": details_text,
            }
            st.session_state.history.append(result)
            st.session_state.latest_result = dict(result)

            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        # Render from session state rather than only on the click run, so the
        # result survives reruns (e.g. the one triggered by the download button).
        res = st.session_state.latest_result
        if res is None:
            st.info("No verification result yet.")
        else:
            result_lines = [
                "<div class='result-box'>",
                f"<p><strong>Claim:</strong> {html.escape(str(res['claim']))}</p>",
                f"<p><strong>Evidence:</strong> {html.escape(str(res['evidence']))}</p>",
                f"<p><strong>Evidence Inference Time:</strong> {res['evidence_time']:.2f} seconds</p>",
                f"<p><strong>Verdict Inference Time:</strong> {res['verdict_time']:.2f} seconds</p>",
                f"<p class='verdict'><span class='verdict-icon'>{verdict_icons.get(res['verdict'], '')}</span>{res['verdict']}</p>",
                res["details"],
                "</div>",
            ]
            # Drop empty entries: a blank line can terminate a Markdown HTML block early.
            st.markdown("\n".join(line for line in result_lines if line), unsafe_allow_html=True)

            # Download Verification Result Feature (plain text, no HTML tags).
            result_text = (
                f"Claim: {res['claim']}\n"
                f"Evidence: {res['evidence']}\n"
                f"Verdict: {res['verdict']}\n"
                f"Details: {res.get('details_text', '')}"
            )
            st.download_button(
                "Download Result",
                data=result_text,
                file_name="verification_result.txt",
                mime="text/plain",
            )
# --- Tab History ---
# Lists every verification of this session, newest first.
with tabs[1]:
    st.subheader("Verification History")
    if st.session_state.history:
        for idx, record in enumerate(reversed(st.session_state.history), 1):
            icon = verdict_icons.get(record["verdict"], "")
            # Collapse newlines/extra spaces from the multi-line text area so a
            # claim cannot break the surrounding Markdown formatting.
            claim_line = " ".join(str(record["claim"]).split())
            # Two spaces before "\n" make a Markdown hard line break; a single
            # space would render claim and result on the same line.
            st.markdown(
                f"**{idx}. Claim:** {claim_line}  \n"
                f"**Result:** {icon} {record['verdict']}"
            )
    else:
        st.write("No verification history yet.")
# --- Tab About ---
# Static project info: badges linking to the paper, models, package and repo.
with tabs[2]:
    st.subheader("About")
    # The arXiv badge label must match the linked paper id (2503.00955).
    st.markdown("""
<p align="center">
<a href="https://arxiv.org/abs/2503.00955">
<img src="https://img.shields.io/badge/arXiv-2503.00955-red?style=flat&label=arXiv" alt="arXiv 2503.00955">
</a>
<a href="https://huggingface.co/SemViQA">
<img src="https://img.shields.io/badge/Hugging%20Face-Model-yellow?style=flat" alt="Hugging Face models">
</a>
<a href="https://pypi.org/project/SemViQA">
<img src="https://img.shields.io/pypi/v/SemViQA?color=blue&label=PyPI" alt="PyPI version">
</a>
<a href="https://github.com/DAVID-NGUYEN-S16/SemViQA">
<img src="https://img.shields.io/github/stars/DAVID-NGUYEN-S16/SemViQA?style=social" alt="GitHub stars">
</a>
</p>
""", unsafe_allow_html=True)
    st.markdown("""
**Description:**
SemViQA is a Semantic QA system designed for fact verification in Vietnamese.
The system extracts evidence from the provided context and classifies claims as **SUPPORTED**, **REFUTED**, or **NEI** (Not Enough Information) using advanced models.
""")