# AllAboutRAG / app.py
# Author: bainskarman — commit 7b666bb ("Update app.py"), 2.65 kB.
import streamlit as st
from PyPDF2 import PdfReader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.chains import RetrievalQA
from langchain.llms import HuggingFacePipeline
import torch
from transformers import pipeline
@st.cache_resource
def load_llm(model_name="HuggingFaceH4/zephyr-7b-alpha"):
    """Build a LangChain LLM backed by a local Hugging Face text-generation pipeline.

    The result is cached with ``st.cache_resource`` so the model weights are
    loaded once per server process (and per distinct *model_name*) instead of
    on every Streamlit rerun.

    Args:
        model_name: Hugging Face Hub model id to load. Defaults to
            Zephyr-7B-alpha; pass another id (e.g. a Mistral-7B variant) to
            swap models without editing this function.

    Returns:
        A ``HuggingFacePipeline`` wrapping the transformers pipeline.
    """
    # float16 halves weight memory compared with float32, and device_map="auto"
    # lets transformers place the layers on whatever devices are available.
    # NOTE(review): device_map="auto" requires the `accelerate` package, and
    # float16 on a CPU-only host may be slow or unsupported — confirm the
    # target hardware.
    generator = pipeline(
        "text-generation",
        model=model_name,
        torch_dtype=torch.float16,
        device_map="auto",
    )
    return HuggingFacePipeline(pipeline=generator)
def _join_page_texts(pages):
    """Join the ``extract_text()`` output of each page, one newline between pages.

    A page whose ``extract_text()`` returns ``None`` or ``""`` contributes an
    empty segment instead of raising. The newline keeps the last word of one
    page from fusing with the first word of the next.
    """
    # str.join is linear; repeated `+=` on a growing string is not guaranteed to be.
    return "\n".join(page.extract_text() or "" for page in pages)


def extract_text_from_pdf(file):
    """Return the text of every page in a PDF, pages separated by newlines.

    Args:
        file: Anything ``PdfReader`` accepts — a filesystem path or a binary
            file-like object such as a Streamlit upload or ``io.BytesIO``.

    Returns:
        The concatenated page text (``""`` for a PDF with no extractable text,
        e.g. a scanned, image-only document).
    """
    return _join_page_texts(PdfReader(file).pages)
def split_text(text, chunk_size=1000, chunk_overlap=200):
    """Break *text* into overlapping chunks suitable for embedding.

    Args:
        text: The full document text.
        chunk_size: Upper bound on each chunk's length, in the splitter's
            length units (characters by default).
        chunk_overlap: How much consecutive chunks share, in the same units.

    Returns:
        The list of chunk strings produced by ``RecursiveCharacterTextSplitter``.
    """
    return RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    ).split_text(text)
def create_vector_store(chunks):
    """Embed *chunks* and index them in an in-memory FAISS store.

    Embeddings come from the ``all-MiniLM-L6-v2`` model via
    ``HuggingFaceEmbeddings``; the embedding model is constructed anew on
    every call.

    Args:
        chunks: The text chunks to index.

    Returns:
        A ``FAISS`` vector store built from *chunks*.
    """
    embedder = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
    return FAISS.from_texts(chunks, embedder)
def query_pdf(vector_store, query, llm):
    """Answer *query* using retrieval-augmented generation over *vector_store*.

    A ``RetrievalQA`` chain of type ``"stuff"`` is built on each call. It
    places the chunks returned by the store's default retriever into a single
    prompt for *llm*.

    Args:
        vector_store: The vector store to retrieve context from.
        query: The user's question.
        llm: The LangChain LLM that generates the answer.

    Returns:
        Whatever ``RetrievalQA.run`` returns for *query* (the answer text).
    """
    retriever = vector_store.as_retriever()
    qa_chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
    )
    return qa_chain.run(query)
@st.cache_resource
def _cached_llm():
    """Load the LLM once per server process and reuse it across reruns."""
    return load_llm()


@st.cache_resource
def _cached_vector_store(pdf_bytes):
    """Build the FAISS index for a PDF's raw bytes, once per distinct content.

    Returns ``None`` when the PDF yields no text (e.g. a scanned, image-only
    document). This avoids calling ``FAISS.from_texts`` with an empty list.
    """
    import io

    text = extract_text_from_pdf(io.BytesIO(pdf_bytes))
    if not text.strip():
        return None
    chunks = split_text(text)
    if not chunks:
        return None
    return create_vector_store(chunks)


def main():
    """Streamlit entry point: choose a PDF, index it, and answer questions about it.

    Streamlit re-executes this on every widget interaction. The expensive
    steps (PDF indexing and LLM loading) therefore go through
    ``st.cache_resource`` helpers so they run only once.
    """
    from pathlib import Path

    st.title("Chat with PDF")
    st.write("Upload a PDF and ask questions about it!")

    # File uploader; fall back to a bundled default PDF when nothing is uploaded.
    uploaded_file = st.file_uploader("Upload a PDF", type="pdf")
    if uploaded_file is None:
        default_pdf = Path("default.pdf")  # Add a default PDF next to the app
        if not default_pdf.is_file():
            st.warning("No PDF uploaded and default.pdf was not found. Please upload a PDF.")
            st.stop()
        st.info("Using default PDF.")
        pdf_bytes = default_pdf.read_bytes()
    else:
        pdf_bytes = uploaded_file.getvalue()

    # Extract, split and embed. This is cached on the PDF bytes, so reruns
    # with the same file skip it.
    vector_store = _cached_vector_store(pdf_bytes)
    if vector_store is None:
        st.error("No extractable text was found in this PDF (it may be a scanned image).")
        st.stop()

    llm = _cached_llm()

    query_method = st.selectbox(
        "Query Translation Method",
        ["Multi-Query", "RAG Fusion", "Decomposition", "Step Back", "HyDE"],
        help="Choose a method to improve query retrieval.",
    )
    # TODO: wire query_method into retrieval. None of these translation
    # methods is implemented yet, so tell the user instead of silently
    # ignoring the choice.
    st.caption(
        f"Note: '{query_method}' is not applied yet; questions use plain similarity retrieval."
    )

    query = st.text_input("Ask a question about the PDF:")
    if query:
        with st.spinner("Generating answer..."):
            # query_pdf returns the answer string (RetrievalQA.run), not a dict.
            answer = query_pdf(vector_store, query, llm)
            # Same call the default retriever makes, so these should be the
            # chunks the chain was given as context.
            sources = vector_store.similarity_search(query)
        st.write("**Answer:**", answer)
        st.write(
            "**Source Text:**",
            "\n\n---\n\n".join(doc.page_content for doc in sources),
        )
if __name__ == "__main__":
    # Launch the app when this file is run as the main script (e.g. `streamlit run`).
    main()