Spaces:
Sleeping
Sleeping
import streamlit as st | |
import torch | |
from transformers import DistilBertTokenizer, DistilBertForQuestionAnswering | |
# Page metadata for the app: the page title and its icon shortcode.
st.set_page_config(
    page_title="Question Answering Tool",
    page_icon=":mag_right:",
)
@st.cache_resource
def load_model():
    """Load the DistilBERT question-answering model and its tokenizer.

    Streamlit re-executes the whole script on every user interaction, so the
    loader is wrapped in ``st.cache_resource``: the objects are created once
    and the same instances are handed back on later reruns instead of being
    re-instantiated each time.

    Returns:
        A ``(model, tokenizer)`` tuple, both loaded from the
        ``distilbert-base-uncased-distilled-squad`` checkpoint.
    """
    checkpoint = "distilbert-base-uncased-distilled-squad"
    model = DistilBertForQuestionAnswering.from_pretrained(checkpoint)
    tokenizer = DistilBertTokenizer.from_pretrained(checkpoint)
    # The model is only ever used for inference in this app.
    model.eval()
    return model, tokenizer
def get_answer(question, text, tokenizer, model, min_answer_words=2):
    """Extract the most relevant answer span for ``question`` from ``text``.

    Questions about the assistant itself ("your name", "who are you",
    "about you") are answered with a fixed identity string without running
    the model.

    Otherwise the question/text pair is tokenized (truncated to 512 tokens),
    the model scores every token as a possible answer start and end, and the
    best-scoring span is decoded. The span search is limited to the tokens
    that follow the first separator token, so the answer is taken from the
    text rather than from the question itself.

    Args:
        question: The user's question.
        text: The passage to search for the answer.
        tokenizer: Callable tokenizer exposing ``decode`` and, optionally,
            ``sep_token_id``.
        model: Callable returning an object with ``start_logits`` and
            ``end_logits``.
        min_answer_words: Minimum number of whitespace-separated words an
            answer must contain to be returned. Defaults to 2; pass 1 to
            accept single-word answers.

    Returns:
        The decoded answer, or a fallback message when no usable span is
        found or the answer is shorter than ``min_answer_words``.
    """
    not_found = "I couldn't find a clear answer in the given text."

    lowered = question.lower()
    if any(phrase in lowered for phrase in ("your name", "who are you", "about you")):
        return "I am Numini, NativUttarMini, created by Sanju Debnath at University of Calcutta."

    # Tokenize the question and passage together as one pair.
    inputs = tokenizer(question, text, return_tensors="pt", truncation=True, padding=True, max_length=512)
    with torch.no_grad():
        outputs = model(**inputs)

    token_ids = inputs["input_ids"][0]

    # Assumes the pair is encoded as "<question> [SEP] <text> ...", so the
    # passage begins right after the first separator. Without a separator id
    # the whole sequence is searched, as before.
    context_start = 0
    sep_id = getattr(tokenizer, "sep_token_id", None)
    if sep_id is not None:
        sep_positions = (token_ids == sep_id).nonzero(as_tuple=True)[0]
        if sep_positions.numel():
            context_start = int(sep_positions[0]) + 1
    if context_start >= token_ids.shape[0]:
        return not_found

    start_idx = context_start + int(torch.argmax(outputs.start_logits[0, context_start:]))
    # +1 turns the inclusive end token into an exclusive slice bound.
    end_idx = context_start + int(torch.argmax(outputs.end_logits[0, context_start:])) + 1

    # Start and end are chosen independently, so the end may precede the start.
    if start_idx >= end_idx:
        return not_found

    answer = tokenizer.decode(token_ids[start_idx:end_idx], skip_special_tokens=True)

    # Reject answers that are too short to be considered meaningful.
    if len(answer.split()) < min_answer_words:
        return "I'm not sure about the exact answer. Can you try rephrasing the question?"
    return answer
def main():
    """Render the question-answering page and show the answer on submit.

    Draws the title and intro text, loads the model and tokenizer, and
    presents a form with a passage field and a question field. When the form
    is submitted, a warning is shown if either field is blank; otherwise the
    answer from ``get_answer`` is displayed.
    """
    # NOTE(review): the leading symbols in the UI strings below look like
    # mis-encoded emoji; they are kept verbatim here -- confirm against the
    # original file before changing them.
    st.title("π Advanced Question Answering Tool")
    st.write("Ask a question based on the given text, and I'll extract the best possible answer.")

    qa_model, qa_tokenizer = load_model()

    with st.form("qa_form"):
        document = st.text_area("π Enter the text/document:", height=200)
        query = st.text_input("β Enter your question:")
        submitted = st.form_submit_button("π Get Answer")

    # Nothing to do until the form has been submitted.
    if not submitted:
        return

    if not document.strip():
        st.warning("β οΈ Please enter some text to analyze.")
        return

    if not query.strip():
        st.warning("β οΈ Please enter a question.")
        return

    with st.spinner("π€ Thinking..."):
        answer = get_answer(query, document, qa_tokenizer, qa_model)
    st.success(f"β Answer: {answer}")
# Start the app only when this file is run as a script, not when imported.
if __name__ == "__main__":
    main()