import html

import streamlit as st
import torch
from transformers import AutoTokenizer

from semviqa.ser.qatc_model import QATCForQuestionAnswering
from semviqa.tvc.model import ClaimModelForClassification
from semviqa.ser.ser_eval import extract_evidence_tfidf_qatc
from semviqa.tvc.tvc_eval import classify_claim

# One device decision shared by model loading and inference, so the model
# weights and the inputs handed to the helpers end up on the same device.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Label for each index predicted by the three-class model (index 0 is NEI).
TC_LABELS = ["NEI", "SUPPORTED", "REFUTED"]


# Load models with caching
@st.cache_resource()
def load_model(model_name, model_class, is_bc=False, device=None):
    """Load a tokenizer and model by name and prepare the model for inference.

    Wrapped in ``st.cache_resource``, so repeated calls with the same
    arguments reuse the already-loaded objects across reruns.

    Args:
        model_name: Identifier passed to ``from_pretrained`` for both the
            tokenizer and the model.
        model_class: Model class exposing ``from_pretrained``.
        is_bc: If True, load the model with 2 labels (binary classifier);
            otherwise with 3 labels.
        device: Device to move the model to; defaults to ``DEVICE``.

    Returns:
        ``(tokenizer, model)`` with the model in eval mode on ``device``.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = model_class.from_pretrained(model_name, num_labels=2 if is_bc else 3)
    model.eval()
    # Inference passes DEVICE to the helpers; without this move the weights
    # would stay on CPU while inputs may be sent to CUDA.
    model.to(device or DEVICE)
    return tokenizer, model


def resolve_verdict(pred_tc, prob3class, pred_bc=None, prob2class=None):
    """Combine the three-class and binary predictions into one verdict label.

    Three-class index 0 means NEI and is returned as-is. Otherwise the binary
    model's answer (0 -> SUPPORTED, anything else -> REFUTED) is used only
    when its probability exceeds the three-class model's; if not, the
    three-class label is kept.

    Args:
        pred_tc: Predicted index from the three-class model.
        prob3class: Probability reported by the three-class model.
        pred_bc: Predicted index from the binary model, or None if the
            binary model was not consulted.
        prob2class: Probability reported by the binary model, or None.

    Returns:
        One of "NEI", "SUPPORTED", "REFUTED".
    """
    if pred_tc == 0 or pred_bc is None:
        return TC_LABELS[pred_tc]
    # Compare confidences first; the former chained conditional let a binary
    # SUPPORTED win unconditionally.
    if prob2class > prob3class:
        return "SUPPORTED" if pred_bc == 0 else "REFUTED"
    return TC_LABELS[pred_tc]


# Page Configuration
st.set_page_config(page_title="SemViQA Demo", layout="wide")

# Custom CSS for improved UI
# NOTE(review): the stylesheet body appears to be missing from this copy; the
# call currently injects only whitespace.
st.markdown(""" """, unsafe_allow_html=True)

# Page Header
# NOTE(review): the HTML markup around these headings appears to have been
# stripped; only the visible text is kept.
st.markdown("""

SemViQA: Vietnamese Fact-Checking System

""", unsafe_allow_html=True)
st.markdown("""

Enter a claim and context to verify its accuracy

""", unsafe_allow_html=True)

# Sidebar: Settings
with st.sidebar.expander("⚙️ Settings", expanded=True):
    tfidf_threshold = st.slider("TF-IDF Threshold", 0.0, 1.0, 0.5, 0.01)
    length_ratio_threshold = st.slider("Length Ratio Threshold", 0.1, 1.0, 0.5, 0.01)
    qatc_model_name = st.selectbox("QATC Model", [
        "SemViQA/qatc-infoxlm-viwikifc",
        "SemViQA/qatc-infoxlm-isedsc01",
        "SemViQA/qatc-vimrc-viwikifc",
        "SemViQA/qatc-vimrc-isedsc01",
    ])
    bc_model_name = st.selectbox("Binary Classification Model", [
        "SemViQA/bc-xlmr-viwikifc",
        "SemViQA/bc-xlmr-isedsc01",
        "SemViQA/bc-infoxlm-viwikifc",
        "SemViQA/bc-infoxlm-isedsc01",
        "SemViQA/bc-erniem-viwikifc",
        "SemViQA/bc-erniem-isedsc01",
    ])
    tc_model_name = st.selectbox("Three-Class Classification Model", [
        "SemViQA/tc-xlmr-viwikifc",
        "SemViQA/tc-xlmr-isedsc01",
        "SemViQA/tc-infoxlm-viwikifc",
        "SemViQA/tc-infoxlm-isedsc01",
        "SemViQA/tc-erniem-viwikifc",
        "SemViQA/tc-erniem-isedsc01",
    ])
    show_details = st.checkbox("Show probability details", value=False)

# Load Models
tokenizer_qatc, model_qatc = load_model(qatc_model_name, QATCForQuestionAnswering)
tokenizer_bc, model_bc = load_model(bc_model_name, ClaimModelForClassification, is_bc=True)
tokenizer_tc, model_tc = load_model(tc_model_name, ClaimModelForClassification)

# Define verdict icons
verdict_icons = {
    "SUPPORTED": "✅",
    "REFUTED": "❌",
    "NEI": "⚠️",
}

# Per-session record of verification results, rendered in the History tab.
if "history" not in st.session_state:
    st.session_state.history = []

# Tabs for functionalities
tabs = st.tabs(["Verify", "History", "About"])

# --- Verify Tab ---
with tabs[0]:
    st.subheader("Verify a Claim")
    claim = st.text_area("Enter Claim", "Vietnam is a country in Southeast Asia.")
    context = st.text_area("Enter Context", "Vietnam is a country located in Southeast Asia.")

    if st.button("Verify", key="verify_button"):
        if not claim.strip() or not context.strip():
            st.warning("Please enter both a claim and a context.")
        else:
            with st.spinner("Verifying..."):
                with torch.no_grad():
                    evidence = extract_evidence_tfidf_qatc(
                        claim,
                        context,
                        model_qatc,
                        tokenizer_qatc,
                        DEVICE,
                        confidence_threshold=tfidf_threshold,
                        length_ratio_threshold=length_ratio_threshold,
                    )
                    prob3class, pred_tc = classify_claim(
                        claim, evidence, model_tc, tokenizer_tc, DEVICE
                    )
                    pred_bc = prob2class = None
                    # The binary model is only consulted when the three-class
                    # model did not predict NEI (index 0).
                    if pred_tc != 0:
                        prob2class, pred_bc = classify_claim(
                            claim, evidence, model_bc, tokenizer_bc, DEVICE
                        )
                    verdict = resolve_verdict(pred_tc, prob3class, pred_bc, prob2class)

            st.session_state.history.append(
                {"claim": claim, "evidence": evidence, "verdict": verdict}
            )

            # Display result. The evidence is extracted from user-supplied
            # context and rendered with unsafe_allow_html=True, so escape it
            # to prevent HTML injection.
            safe_evidence = html.escape(str(evidence))
            st.markdown(f"""

Result

Evidence: {safe_evidence}

{verdict_icons.get(verdict, '')}{verdict}

""", unsafe_allow_html=True)

            if show_details:
                st.write(f"Three-class model probability: {prob3class}")
                if prob2class is not None:
                    st.write(f"Binary model probability: {prob2class}")

            if torch.cuda.is_available():
                torch.cuda.empty_cache()

# --- History Tab ---
with tabs[1]:
    st.subheader("Verification History")
    if not st.session_state.history:
        st.info("No claims verified yet in this session.")
    # Most recent first.
    for entry in reversed(st.session_state.history):
        icon = verdict_icons.get(entry["verdict"], "")
        with st.expander(f"{icon} {entry['verdict']}: {entry['claim']}"):
            st.write("Evidence:", entry["evidence"])

# --- About Tab ---
with tabs[2]:
    st.subheader("About SemViQA")
    st.markdown("""SemViQA is a semantic fact-checking system for Vietnamese information verification.""")