Question Answering
import streamlit as st
import torch
from transformers import DistilBertTokenizer, DistilBertForMaskedLM
from qa_model import ReuseQuestionDistilBERT

@st.cache_resource
def load_model():
    """Load the fine-tuned QA model and tokenizer once, then cache them."""
    try:
        # Reuse the pretrained DistilBERT encoder as the backbone.
        backbone = DistilBertForMaskedLM.from_pretrained("distilbert-base-uncased").distilbert
        model = ReuseQuestionDistilBERT(backbone)
        # Load the fine-tuned weights on CPU so the app also runs without a GPU.
        model.load_state_dict(torch.load("distilbert_reuse.model", map_location=torch.device("cpu")))
        model.eval()  # inference mode: disables dropout
        tokenizer = DistilBertTokenizer.from_pretrained("qa_tokenizer")
        return model, tokenizer
    except Exception as e:
        st.error(f"Error loading model: {e}")
        return None, None

def get_answer(question, text, tokenizer, model):
    if model is None or tokenizer is None:
        return "Model not loaded properly."

    question = [question.strip()]
    text = [text.strip()]

    # Encode the (question, context) pair; "only_second" truncates only the
    # context if the pair exceeds DistilBERT's 512-token limit.
    inputs = tokenizer(
        question,
        text,
        max_length=512,
        truncation="only_second",
        padding="max_length",
        return_tensors="pt"
    )

    # Forward pass without gradient tracking; the position arguments are None
    # because we are predicting a span, not training.
    with torch.no_grad():
        outputs = model(
            inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            start_positions=None,
            end_positions=None
        )

    if "start_logits" not in outputs or "end_logits" not in outputs:
        return "Error: Model output structure is incorrect."

    # Most likely start/end token positions for the answer span. .item()
    # converts the one-element argmax tensors to plain ints for slicing.
    start = torch.argmax(outputs["start_logits"], dim=1).item()
    end = torch.argmax(outputs["end_logits"], dim=1).item()
    if end < start:
        # An inverted span means the model found no plausible answer.
        return "No answer found."

    # Decode the predicted span back into text.
    ans_tokens = inputs["input_ids"][0, start:end + 1]
    answer_tokens = tokenizer.convert_ids_to_tokens(ans_tokens, skip_special_tokens=True)
    predicted = tokenizer.convert_tokens_to_string(answer_tokens)
    return predicted or "No answer found."

def main():
    st.set_page_config(page_title="Question Answering Tool", page_icon=":mag_right:")
    st.write("# Question Answering Tool")
    
    model, tokenizer = load_model()
    
    with st.form("qa_form"):
        text = st.text_area("Enter your text here")
        question = st.text_input("Enter your question here")
        
        if st.form_submit_button("Submit"):
            if not text or not question:
                st.warning("Please enter both text and a question.")
            else:
                # Show a spinner while the model runs instead of leaving a
                # stale "Processing..." line on screen.
                with st.spinner("Processing..."):
                    answer = get_answer(question, text, tokenizer, model)
                st.text(f"Answer: {answer}")

if __name__ == "__main__":
    main()
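
The app imports ReuseQuestionDistilBERT from qa_model, which is not shown on this page. As a rough sketch only, assuming the class wraps the DistilBERT encoder with a span-prediction head and returns a dict containing "start_logits" and "end_logits" (the interface the code above relies on), it might look something like the following; the head shape and dropout rate are guesses, not the actual implementation:

import torch
import torch.nn as nn

class ReuseQuestionDistilBERT(nn.Module):
    """Hypothetical sketch: reused DistilBERT backbone + QA span head.

    Assumed interface only -- the real qa_model.py may differ.
    """

    def __init__(self, distilbert):
        super().__init__()
        self.distilbert = distilbert           # reused pretrained encoder
        self.dropout = nn.Dropout(0.1)         # assumed dropout rate
        # Two logits per token: one for span start, one for span end.
        self.qa_outputs = nn.Linear(distilbert.config.dim, 2)

    def forward(self, input_ids, attention_mask=None,
                start_positions=None, end_positions=None):
        hidden = self.distilbert(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state                    # (batch, seq_len, dim)
        logits = self.qa_outputs(self.dropout(hidden))
        start_logits, end_logits = logits.split(1, dim=-1)
        out = {
            "start_logits": start_logits.squeeze(-1),
            "end_logits": end_logits.squeeze(-1),
        }
        # During training, a cross-entropy loss over the gold span positions
        # would typically be added here; inference passes None for both.
        if start_positions is not None and end_positions is not None:
            loss_fn = nn.CrossEntropyLoss()
            out["loss"] = (
                loss_fn(out["start_logits"], start_positions)
                + loss_fn(out["end_logits"], end_positions)
            ) / 2
        return out

With distilbert_reuse.model and the qa_tokenizer directory in place next to the script, the app starts with streamlit run on this file.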