import streamlit as st
import torch
from transformers import DistilBertTokenizer, DistilBertForMaskedLM

from qa_model import ReuseQuestionDistilBERT


@st.cache_resource
def load_model():
    """Load the fine-tuned QA model and its tokenizer, caching them across reruns."""
    try:
        # Reuse the pretrained DistilBERT encoder as the backbone.
        backbone = DistilBertForMaskedLM.from_pretrained(
            "distilbert-base-uncased"
        ).distilbert
        model = ReuseQuestionDistilBERT(backbone)
        model.load_state_dict(
            torch.load("distilbert_reuse.model", map_location=torch.device("cpu"))
        )
        model.eval()  # disable dropout for inference
        tokenizer = DistilBertTokenizer.from_pretrained("qa_tokenizer")
        return model, tokenizer
    except Exception as e:
        st.error(f"Error loading model: {e}")
        return None, None


def get_answer(question, text, tokenizer, model):
    """Extract the answer span for `question` from `text`."""
    if model is None or tokenizer is None:
        return "Model not loaded properly."

    # Encode the question/context pair; truncate only the context
    # ("only_second") so the question is never cut off.
    inputs = tokenizer(
        [question.strip()],
        [text.strip()],
        max_length=512,
        truncation="only_second",
        padding="max_length",
        return_tensors="pt",
    )
    with torch.no_grad():
        outputs = model(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            start_positions=None,
            end_positions=None,
        )
    if "start_logits" not in outputs or "end_logits" not in outputs:
        return "Error: Model output structure is incorrect."

    # Pick the most likely start and end token positions.
    start = torch.argmax(outputs["start_logits"], dim=1).item()
    end = torch.argmax(outputs["end_logits"], dim=1).item()
    if end < start:
        return "No answer found."

    # Decode the predicted token span back into a string.
    ans_tokens = inputs["input_ids"][0, start : end + 1]
    answer_tokens = tokenizer.convert_ids_to_tokens(ans_tokens, skip_special_tokens=True)
    predicted = tokenizer.convert_tokens_to_string(answer_tokens)
    return predicted or "No answer found."


def main():
    st.set_page_config(page_title="Question Answering Tool", page_icon=":mag_right:")
    st.write("# Question Answering Tool")
    model, tokenizer = load_model()

    with st.form("qa_form"):
        text = st.text_area("Enter your text here")
        question = st.text_input("Enter your question here")
        if st.form_submit_button("Submit"):
            if not text or not question:
                st.warning("Please enter both text and a question.")
            else:
                st.text("Processing...")
                answer = get_answer(question, text, tokenizer, model)
                st.text(f"Answer: {answer}")


if __name__ == "__main__":
    main()
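
# ---------------------------------------------------------------------------
# Usage sketch (assumptions: this file is saved as app.py, and the
# "distilbert_reuse.model" checkpoint and "qa_tokenizer" directory live in the
# working directory). Launch the UI with:
#
#     streamlit run app.py
#
# For a quick smoke test without the UI, the pipeline can also be exercised
# directly; the passage and question below are illustrative only:
#
#     model, tokenizer = load_model()
#     print(get_answer(
#         "Where was the telescope launched from?",
#         "The James Webb Space Telescope was launched from Kourou, French Guiana.",
#         tokenizer,
#         model,
#     ))
# ---------------------------------------------------------------------------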