# AllAboutRAG / app.py
import streamlit as st
import os
from huggingface_hub import InferenceApi
from PyPDF2 import PdfReader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langdetect import detect
# Load the Hugging Face token from environment variables (secrets)
token = os.environ.get("KEY2") # Replace "KEY2" with your secret key name
# Initialize the Hugging Face Inference API
def load_llm():
    model_name = "HuggingFaceH4/zephyr-7b-alpha"  # Replace with your preferred model
    api = InferenceApi(repo_id=model_name, token=token)
    return api
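# Note: InferenceApi is the legacy client in huggingface_hub. A roughly equivalent
# sketch with the newer InferenceClient (an assumption, not what this app originally
# used) would look like:
#
#     from huggingface_hub import InferenceClient
#     client = InferenceClient(model="HuggingFaceH4/zephyr-7b-alpha", token=token)
#     answer = client.text_generation(prompt, max_new_tokens=200, temperature=0.7, top_k=50)
#
# text_generation() returns the generated string directly instead of a list of dicts.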
# Extract text from PDF
def extract_text_from_pdf(file):
    reader = PdfReader(file)
    text = ""
    for page in reader.pages:
        # extract_text() may return None for pages with no extractable text
        text += page.extract_text() or ""
    return text
# Split text into chunks
def split_text(text, chunk_size=1000, chunk_overlap=200):
    splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    chunks = splitter.split_text(text)
    return chunks
# Create embeddings and vector store
def create_vector_store(chunks, indexing_method="multi-representation", **kwargs):
    embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
    if indexing_method == "multi-representation":
        vector_store = FAISS.from_texts(chunks, embeddings)
    elif indexing_method == "raptors":
        # RAPTOR logic (e.g., hierarchical chunking) is not implemented yet;
        # this currently falls back to a flat FAISS index.
        vector_store = FAISS.from_texts(chunks, embeddings)
    elif indexing_method == "colbert":
        # ColBERT logic (e.g., contextualized late-interaction embeddings) is not
        # implemented yet; this currently falls back to a flat FAISS index.
        vector_store = FAISS.from_texts(chunks, embeddings)
    return vector_store
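# One possible way to make the "multi-representation" option behave differently from
# the flat index is LangChain's MultiVectorRetriever: embed a compact representation
# of each chunk, but return the full chunk at query time. This is only a sketch under
# assumptions: truncated chunks stand in for LLM-written summaries, and the helper
# name create_multi_representation_retriever is not part of the original app.
def create_multi_representation_retriever(chunks):
    import uuid
    from langchain.retrievers.multi_vector import MultiVectorRetriever
    from langchain.storage import InMemoryStore
    from langchain_core.documents import Document

    embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
    doc_ids = [str(uuid.uuid4()) for _ in chunks]
    # Compact representations that get embedded (here: the first 300 characters).
    summaries = [
        Document(page_content=chunk[:300], metadata={"doc_id": doc_ids[i]})
        for i, chunk in enumerate(chunks)
    ]
    vectorstore = FAISS.from_documents(summaries, embeddings)
    # The docstore holds the full chunks that are returned to the caller.
    store = InMemoryStore()
    store.mset(list(zip(doc_ids, [Document(page_content=c) for c in chunks])))
    return MultiVectorRetriever(vectorstore=vectorstore, docstore=store, id_key="doc_id")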
# Query the PDF using the Hugging Face API
def query_pdf(vector_store, query, api, query_method="multi-query", max_new_tokens=200, temperature=0.7, top_k=50):
    # Note: query_method is accepted here but no query-translation strategy is applied yet.
    # Retrieve relevant chunks from the vector store
    docs = vector_store.similarity_search(query)
    context = " ".join([doc.page_content for doc in docs])
    # Build a prompt for the LLM
    prompt = f"Context: {context}\n\nQuestion: {query}\n\nAnswer:"
    # Query the Hugging Face API (InferenceApi takes the generation settings as `params`)
    response = api(
        inputs=prompt,
        params={
            "max_new_tokens": max_new_tokens,
            "temperature": temperature,
            "top_k": top_k,
        },
    )
    return response[0]["generated_text"], docs
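# The UI below offers query-translation methods ("multi-query", "rag-fusion",
# "decomposition", "step-back", "hyde") that query_pdf does not implement. A minimal
# sketch of "multi-query" retrieval, under assumptions (a hand-rolled paraphrase
# prompt; the helper name multi_query_search is not part of the original app):
def multi_query_search(vector_store, query, api, n_variants=3, k=4):
    # Ask the LLM to rephrase the question a few ways, one per line.
    prompt = f"Rewrite the following question in {n_variants} different ways, one per line:\n{query}"
    response = api(inputs=prompt, params={"max_new_tokens": 100, "temperature": 0.7})
    variants = [line.strip() for line in response[0]["generated_text"].splitlines() if line.strip()]
    # Retrieve for the original query and each variant, de-duplicating by content.
    seen, merged = set(), []
    for q in [query] + variants[:n_variants]:
        for doc in vector_store.similarity_search(q, k=k):
            if doc.page_content not in seen:
                seen.add(doc.page_content)
                merged.append(doc)
    return merged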
# Detect language of the text
def detect_language(text):
    try:
        return detect(text)
    except Exception:
        return "en"  # Default to English if detection fails
# Streamlit App
def main():
    st.title("Chat with PDF")
    st.write("Upload a PDF and ask questions about it!")

    # File uploader
    uploaded_file = st.file_uploader("Upload a PDF", type="pdf")
    if uploaded_file is None:
        st.info("Using default PDF.")
        uploaded_file = "default.pdf"  # Add a default PDF
    # Step 1: Extract text and split into chunks
    if "text" not in st.session_state:
        st.session_state.text = None
    if "chunks" not in st.session_state:
        st.session_state.chunks = None
    if st.button("Extract Text and Split into Chunks"):
        st.session_state.text = extract_text_from_pdf(uploaded_file)
        st.session_state.chunks = split_text(st.session_state.text)
        st.success("Text extracted and split into chunks!")
    # Step 2: Create vector store
    if "vector_store" not in st.session_state:
        st.session_state.vector_store = None
    if st.session_state.chunks:
        st.subheader("Indexing Options")
        indexing_method = st.selectbox(
            "Indexing Method",
            ["multi-representation", "raptors", "colbert"],
            help="Choose how to index the PDF text."
        )
        if st.button("Create Vector Store"):
            st.session_state.vector_store = create_vector_store(st.session_state.chunks, indexing_method=indexing_method)
            st.success("Vector store created!")
    # Step 3: Load LLM (Hugging Face API)
    if "api" not in st.session_state:
        st.session_state.api = None
    if st.session_state.vector_store:
        st.subheader("LLM Parameters")
        temperature = st.slider("Temperature", 0.1, 1.0, 0.7, help="Controls randomness in the output.")
        top_k = st.slider("Top-k", 1, 100, 50, help="Limits sampling to the top-k tokens.")
        max_new_tokens = st.slider("Max New Tokens", 50, 500, 200, help="Maximum number of tokens to generate.")
        if st.button("Load LLM"):
            api = load_llm()
            st.session_state.api = api
            st.success("LLM loaded!")
        # Step 4: Query the PDF
        if st.session_state.api:
            st.subheader("Query Translation Options")
            query_method = st.selectbox(
                "Query Translation Method",
                ["multi-query", "rag-fusion", "decomposition", "step-back", "hyde"],
                help="Choose a method to improve query retrieval."
            )
            query = st.text_input("Ask a question about the PDF:")
            if query:
                answer, source_docs = query_pdf(
                    st.session_state.vector_store,
                    query,
                    st.session_state.api,
                    query_method=query_method,
                    max_new_tokens=max_new_tokens,
                    temperature=temperature,
                    top_k=top_k,
                )
                st.write("**Answer:**", answer)
                st.write("**Source Text:**")
                for doc in source_docs:
                    st.write(doc.page_content)
if __name__ == "__main__":
    main()