File size: 2,255 Bytes
ae479fd 076c725 ae479fd 076c725 114c773 076c725 20c970d 076c725 20c970d 076c725 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 |
import streamlit as st
import faiss
import numpy as np
import pickle
from sentence_transformers import SentenceTransformer
from transformers import pipeline
# File names for saved data
# FAISS index file, read by load_index_and_chunks() via faiss.read_index
INDEX_FILE = "faiss_index.index"
# Pickle file holding the chunks that main() looks up by FAISS result id
CHUNKS_FILE = "chunks.pkl"
# Model name passed to SentenceTransformer to embed the user's question
# NOTE(review): presumably must match the model used when the index was built - confirm
EMBEDDING_MODEL_NAME = "all-MiniLM-L6-v2"
# Passed to the question-answering pipeline as both its model and its tokenizer
QA_MODEL_NAME = "deepset/roberta-large-squad2" # You can change this to any Hugging Face QA model
@st.cache_resource
def load_index_and_chunks():
    """Read the saved FAISS index and its pickled chunks from disk.

    Returns:
        A ``(index, chunks)`` tuple: the object produced by
        ``faiss.read_index(INDEX_FILE)`` and whatever object was pickled
        into ``CHUNKS_FILE``.
    """
    faiss_index = faiss.read_index(INDEX_FILE)
    # NOTE: unpickling can execute code embedded in the file, so CHUNKS_FILE
    # must only ever come from a trusted source.
    with open(CHUNKS_FILE, "rb") as chunk_file:
        stored_chunks = pickle.load(chunk_file)
    return faiss_index, stored_chunks
@st.cache_resource
def load_embedding_model():
    """Build the SentenceTransformer named by ``EMBEDDING_MODEL_NAME``.

    Returns:
        The ``SentenceTransformer`` instance used to embed questions.
    """
    embedder = SentenceTransformer(EMBEDDING_MODEL_NAME)
    return embedder
@st.cache_resource
def load_qa_pipeline():
    """Build the Hugging Face question-answering pipeline.

    ``QA_MODEL_NAME`` is used for both the model and the tokenizer.

    Returns:
        A pipeline that is called with a ``question`` and a ``context``.
    """
    qa = pipeline(
        "question-answering",
        model=QA_MODEL_NAME,
        tokenizer=QA_MODEL_NAME,
    )
    return qa
def main():
    """Render the Streamlit page: take a question, retrieve context, answer it.

    The question is embedded with the sentence-transformer model, the FAISS
    index is searched for the closest chunks, and the joined chunks are
    handed to the question-answering pipeline as its context. The retrieved
    chunks are also shown in a collapsible expander.

    Returns:
        None. All output is written to the Streamlit page.
    """
    st.title("PDF Question-Answering App")

    # Load FAISS index, chunks, and models
    index, chunks = load_index_and_chunks()
    embed_model = load_embedding_model()
    qa_pipeline = load_qa_pipeline()

    st.write("Enter your question about the PDF document:")
    query = st.text_input("Question:")

    # Nothing to do until the user types something; whitespace-only input
    # is treated the same as an empty box instead of being sent to the models.
    query = query.strip() if query else ""
    if not query:
        return

    # Encode the query using the same SentenceTransformer model
    query_embedding = embed_model.encode([query]).astype("float32")

    # Retrieve top k relevant chunks
    k = 3
    _distances, indices = index.search(query_embedding, k)

    # FAISS pads the result with -1 when it finds fewer than k neighbours.
    # Those ids must be dropped: chunks[-1] would silently return the last
    # chunk as if it were a real match.
    retrieved = [chunks[int(idx)] for idx in indices[0] if idx >= 0]

    if not retrieved:
        st.warning("No relevant context was found for this question.")
        return

    # Prepare combined context from the retrieved chunks
    context = " ".join(retrieved)

    # Use an expander to optionally display the retrieved context
    with st.expander("Show Retrieved Context"):
        for piece in retrieved:
            st.write(piece)

    st.subheader("Answer:")
    try:
        # Use the QA pipeline to generate an answer based on the combined context
        result = qa_pipeline(question=query, context=context)
    except Exception as e:
        # UI boundary: report any model failure on the page rather than
        # crashing the Streamlit script run.
        st.error(f"Error generating answer: {e}")
    else:
        st.write(result["answer"])
# Script entry point: build the page only when this file is run directly.
if __name__ == "__main__":
    main()
|