File size: 2,255 Bytes
ae479fd
076c725
 
 
 
 
ae479fd
076c725
 
 
 
114c773
076c725
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20c970d
076c725
 
 
 
20c970d
 
 
 
 
076c725
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import streamlit as st
import faiss
import numpy as np
import pickle
from sentence_transformers import SentenceTransformer
from transformers import pipeline

# Saved artifacts loaded at startup by load_index_and_chunks().
INDEX_FILE: str = "faiss_index.index"  # serialized FAISS vector index
CHUNKS_FILE: str = "chunks.pkl"  # pickled text chunks; position i must match FAISS id i

# Hugging Face model identifiers.
# The embedding model must be the same one used to build INDEX_FILE.
EMBEDDING_MODEL_NAME: str = "all-MiniLM-L6-v2"
# Any Hugging Face extractive question-answering model can be substituted here.
QA_MODEL_NAME: str = "deepset/roberta-large-squad2"

@st.cache_resource
def load_index_and_chunks():
    """Read the FAISS index and the chunk list that goes with it.

    Wrapped in ``st.cache_resource`` so the loaded objects are cached and
    reused rather than re-read from disk on every call.

    Returns:
        tuple: ``(index, chunks)`` as read from ``INDEX_FILE`` and
        ``CHUNKS_FILE``.

    NOTE(security): ``pickle.load`` can execute arbitrary code, so
    ``CHUNKS_FILE`` must only ever come from a trusted source.
    """
    faiss_index = faiss.read_index(INDEX_FILE)
    with open(CHUNKS_FILE, "rb") as handle:
        stored_chunks = pickle.load(handle)
    return faiss_index, stored_chunks

@st.cache_resource
def load_embedding_model():
    """Build the SentenceTransformer that embeds user questions.

    Cached via ``st.cache_resource`` so the model is constructed once and reused.
    """
    encoder = SentenceTransformer(EMBEDDING_MODEL_NAME)
    return encoder

@st.cache_resource
def load_qa_pipeline():
    """Create the Hugging Face question-answering pipeline.

    The returned callable is invoked with ``question=`` and ``context=``
    keyword arguments. Model and tokenizer are both loaded from
    ``QA_MODEL_NAME``. Cached via ``st.cache_resource``.
    """
    model_id = QA_MODEL_NAME
    return pipeline(
        task="question-answering",
        model=model_id,
        tokenizer=model_id,
    )

def _retrieve_chunks(index, chunks, query_embedding, k):
    """Return up to ``k`` stored chunks nearest to ``query_embedding``, best first.

    FAISS pads its result with id ``-1`` when the index holds fewer than
    ``k`` vectors. Those ids are dropped here because ``chunks[-1]`` would
    otherwise silently return the last chunk and inject it into the context.
    """
    k = min(k, index.ntotal)
    if k <= 0:
        return []
    _, ids = index.search(query_embedding, k)
    return [chunks[i] for i in ids[0] if i >= 0]


def main(top_k=3):
    """Render the PDF question-answering page.

    Loads the cached FAISS index, chunk list and models. For a non-blank
    question it then:
    - embeds the question;
    - retrieves the ``top_k`` nearest chunks;
    - shows those chunks in an expander;
    - runs extractive QA over their concatenation.

    Args:
        top_k: Number of chunks to retrieve per question. The default of 3
            matches the previously hard-coded value.
    """
    st.title("PDF Question-Answering App")

    # Load FAISS index, chunks, and models (each loader is st.cache_resource-wrapped).
    index, chunks = load_index_and_chunks()
    embed_model = load_embedding_model()
    qa_pipeline = load_qa_pipeline()

    st.write("Enter your question about the PDF document:")
    # `or ""` guards against a None value; strip() rejects whitespace-only input.
    query = (st.text_input("Question:") or "").strip()
    if not query:
        return

    # Encode with the same SentenceTransformer used to build the index.
    # FAISS search expects float32 input.
    query_embedding = embed_model.encode([query]).astype("float32")
    retrieved = _retrieve_chunks(index, chunks, query_embedding, top_k)

    # Let the user optionally inspect what the answer is based on.
    with st.expander("Show Retrieved Context"):
        for chunk in retrieved:
            st.write(chunk)

    st.subheader("Answer:")
    if not retrieved:
        # Running QA on an empty context can only fail or return noise.
        st.warning("No relevant context was found in the document.")
        return

    context = " ".join(retrieved)
    try:
        result = qa_pipeline(question=query, context=context)
        st.write(result["answer"])
    except Exception as e:  # UI boundary: report model errors instead of crashing the page
        st.error(f"Error generating answer: {e}")

# Entry point: render the app when this file is executed as the main module.
if __name__ == "__main__":
    main()