# AllAboutRAG / app.py
# Last change by bainskarman: "Update app.py" (commit 28202fc, verified)
import streamlit as st
import os
import requests
import faiss
import numpy as np
from pdfminer.high_level import extract_text
from sentence_transformers import SentenceTransformer
from langdetect import detect
# Hugging Face Inference API settings: the access token is read from the
# "Key2" environment variable (None when it is not set).
HUGGINGFACE_TOKEN = os.getenv("Key2")
HF_MODEL = "HuggingFaceH4/zephyr-7b-alpha"

# Sentence-embedding model used both to index PDF chunks and to embed queries.
EMBEDDER = SentenceTransformer("all-MiniLM-L6-v2")
# Query-translation prompt templates, keyed by the method name shown in the
# sidebar. Each template has a single ``{question}`` placeholder.
SYSTEM_PROMPTS = dict(
    [
        ("Multi-Query", "Generate five alternative versions of the user question: {question}"),
        ("RAG Fusion", "Combine multiple queries into a single, refined query: {question}"),
        ("Decomposition", "Break down the user question into simpler sub-questions: {question}"),
        ("Step Back", "Refine the user question by asking a more general question: {question}"),
        ("HyDE", "Generate a hypothetical document relevant to the user question: {question}"),
    ]
)
# Helper function to interact with Hugging Face API
def query_hf(prompt, max_new_tokens=1000, temperature=0.7, top_k=50,
             return_full_text=False, timeout=60):
    """Send *prompt* to the Hugging Face Inference API for HF_MODEL.

    Args:
        prompt: Text for the model to continue.
        max_new_tokens: Forwarded as the ``max_new_tokens`` generation parameter.
        temperature: Forwarded as the ``temperature`` generation parameter.
        top_k: Forwarded as the ``top_k`` generation parameter.
        return_full_text: Forwarded to the API. When False (the default), only
            the newly generated text is returned, not *prompt* echoed in front
            of it.
        timeout: Seconds to wait for the HTTP request before giving up.

    Returns:
        The ``generated_text`` of the first result. Returns "" when the
        request fails, the status is not 200, or the body does not have the
        expected ``[{"generated_text": ...}]`` shape. Each failure is reported
        in the UI with ``st.error``.
    """
    headers = {}
    # Only authenticate when a token is configured; otherwise the header
    # would literally read "Bearer None".
    if HUGGINGFACE_TOKEN:
        headers["Authorization"] = f"Bearer {HUGGINGFACE_TOKEN}"
    payload = {
        "inputs": prompt,
        "parameters": {
            "max_new_tokens": max_new_tokens,
            "temperature": temperature,
            "top_k": top_k,
            "return_full_text": return_full_text,
        },
    }
    url = f"https://api-inference.huggingface.co/models/{HF_MODEL}"
    try:
        response = requests.post(url, headers=headers, json=payload, timeout=timeout)
    except requests.RequestException as err:
        st.error(f"Error: request to Hugging Face failed - {err}")
        return ""
    if response.status_code != 200:
        st.error(f"Error: {response.status_code} - {response.text}")
        return ""
    try:
        return response.json()[0]["generated_text"]
    except (ValueError, LookupError, TypeError) as err:
        # ValueError: body is not JSON; LookupError/TypeError: JSON of an
        # unexpected shape (e.g. an error object instead of a list).
        st.error(f"Error: unexpected response from Hugging Face - {err}")
        return ""
# Extract text from PDF
def extract_pdf_text(pdf_file):
    """Return the text of *pdf_file*, as extracted by pdfminer, split on "\\n"."""
    full_text = extract_text(pdf_file)
    lines = full_text.split("\n")
    return lines
# Chunk text into segments
def chunk_text(text_lines, chunk_size=500):
    """Group the whitespace-separated words of *text_lines* into chunks.

    Args:
        text_lines: Iterable of strings; they are joined with spaces and
            re-split on whitespace.
        chunk_size: Maximum number of words per chunk.

    Returns:
        List of strings, each holding up to *chunk_size* words joined by a
        single space. The last chunk may be shorter. Empty input gives [].
    """
    tokens = " ".join(text_lines).split()
    chunks = []
    for start in range(0, len(tokens), chunk_size):
        window = tokens[start:start + chunk_size]
        chunks.append(" ".join(window))
    return chunks
# Build FAISS Index
def build_index(embeddings):
    """Create a flat (exact) L2 FAISS index and add every row of *embeddings*.

    Args:
        embeddings: Array whose ``shape[1]`` is the vector width (one row per
            vector).

    Returns:
        The populated ``faiss.IndexFlatL2``.
    """
    flat_index = faiss.IndexFlatL2(embeddings.shape[1])
    flat_index.add(embeddings)
    return flat_index
# Search FAISS Index
def search_index(query_embedding, index, top_k=5):
    """Return the ids of the stored vectors nearest to the first query row.

    Args:
        query_embedding: 2-D array of query vectors. Only the results for
            row 0 are returned.
        index: FAISS index (anything exposing ``ntotal`` and ``search(x, k)``).
        top_k: Maximum number of neighbours to return.

    Returns:
        1-D integer numpy array with at most ``min(top_k, index.ntotal)`` ids,
        in the order FAISS returned them. FAISS pads missing results with -1.
        Those padding entries are removed, so callers never index a list
        with -1 and silently get its last element.
    """
    k = min(int(top_k), int(index.ntotal))
    if k <= 0:
        # Nothing to search (empty index or non-positive top_k).
        return np.empty(0, dtype=np.int64)
    _, indices = index.search(query_embedding, k)
    row = indices[0]
    return row[row >= 0]
# Embed PDF content and build FAISS index
def process_pdf(pdf_file):
    """Extract, chunk, and embed *pdf_file*.

    Returns:
        ``(chunks, faiss_index)``: the list of text chunks and an index whose
        row ``i`` holds the embedding of ``chunks[i]``.

    Raises:
        ValueError: if the PDF yields no words at all (for example an
            image-only scan). An index cannot be built from zero embeddings;
            without this check ``build_index`` would fail with an opaque
            IndexError on ``shape[1]``.
    """
    text_lines = extract_pdf_text(pdf_file)
    chunks = chunk_text(text_lines)
    if not chunks:
        raise ValueError("No extractable text found in the PDF.")
    embeddings = EMBEDDER.encode(chunks, convert_to_tensor=False)
    faiss_index = build_index(np.array(embeddings))
    return chunks, faiss_index
# Generate query translations
def translate_query(prompt, method, max_new_tokens, temperature, top_k):
    """Rewrite *prompt* using the query-translation strategy *method*.

    Args:
        prompt: The user's question.
        method: Key of SYSTEM_PROMPTS selecting the rewrite template.
        max_new_tokens, temperature, top_k: Generation parameters passed to
            query_hf.

    Returns:
        list[str]: the non-blank lines of the model output, each stripped of
        surrounding whitespace. The list is empty when the model call
        produced no text (for example on an API error), so blank strings are
        never embedded and searched as queries.
    """
    formatted_prompt = SYSTEM_PROMPTS[method].format(question=prompt)
    # ``or ""`` guards against query_hf returning None on failure.
    generated = query_hf(formatted_prompt, max_new_tokens, temperature, top_k) or ""
    stripped = (line.strip() for line in generated.split("\n"))
    return [query for query in stripped if query]
# Retrieve relevant chunks from FAISS index
def retrieve_chunks(translated_queries, faiss_index, chunks, top_k=5):
    """Collect the chunks nearest to each translated query.

    Each query is embedded with EMBEDDER and looked up in *faiss_index* via
    search_index.

    Args:
        translated_queries: Iterable of query strings.
        faiss_index: Index built from the embeddings of *chunks*.
        chunks: Text chunks; position ``i`` corresponds to index row ``i``.
        top_k: Neighbours requested per query.

    Returns:
        list[str]: the matched chunks in query order.
        - A chunk matched by several queries appears only once, at its first
          occurrence, so the context passed to the LLM is not bloated with
          repeats.
        - Ids outside ``range(len(chunks))`` are skipped. This includes the
          -1 padding FAISS returns when it has fewer than *top_k* vectors;
          indexing with it would silently return the last chunk.
    """
    relevant_chunks = []
    seen = set()
    for query in translated_queries:
        query_embedding = EMBEDDER.encode([query], convert_to_tensor=False)
        for raw_id in search_index(np.array(query_embedding), faiss_index, top_k):
            chunk_id = int(raw_id)
            if 0 <= chunk_id < len(chunks) and chunk_id not in seen:
                seen.add(chunk_id)
                relevant_chunks.append(chunks[chunk_id])
    return relevant_chunks
# Generate final response using RAG approach
def generate_final_response(prompt, context, max_new_tokens, temperature, top_k):
    """Ask the model to answer *prompt*, with the retrieved *context* prepended.

    Returns whatever query_hf returns for the combined prompt.
    """
    template = "Context: {}\n\nAnswer this question: {}"
    rag_prompt = template.format(context, prompt)
    return query_hf(rag_prompt, max_new_tokens, temperature, top_k)
# Streamlit UI
def main():
    """Render the four-step RAG workflow.

    The steps are: embed a PDF, translate the query, retrieve chunks, and
    generate an answer. Each step's result is stored in ``st.session_state``
    so later steps can use it on subsequent Streamlit reruns. A step whose
    prerequisites are missing shows a warning instead of failing.
    """
    st.title("Enhanced RAG Model with FAISS Indexing")

    # Sidebar Inputs
    pdf_file = st.sidebar.file_uploader("Upload PDF", type="pdf")
    query_translation = st.sidebar.selectbox("Query Translation Method", list(SYSTEM_PROMPTS.keys()))
    # NOTE(review): this choice only toggles the K slider below; it does not
    # change how retrieval measures similarity (build_index always creates a
    # faiss.IndexFlatL2).
    similarity_method = st.sidebar.selectbox("Similarity Search Method", ["Cosine Similarity", "KNN"])
    k_value = st.sidebar.slider("K Value (for KNN)", 1, 10, 5) if similarity_method == "KNN" else 5
    max_new_tokens = st.sidebar.slider("Max New Tokens", 10, 1000, 500)
    temperature = st.sidebar.slider("Temperature", 0.1, 1.0, 0.7)
    top_k = st.sidebar.slider("Top K", 1, 100, 50)

    # Input Prompt
    prompt = st.text_input("Enter your query:")

    # State Management: every key read below must exist before any button is
    # pressed, otherwise attribute access on session_state raises.
    defaults = {
        "chunks": [],
        "faiss_index": None,
        "translated_queries": [],
        "relevant_chunks": [],
    }
    for key, value in defaults.items():
        if key not in st.session_state:
            st.session_state[key] = value

    # Step 1: Process PDF
    if st.button("1. Embed PDF"):
        if pdf_file is None:
            st.warning("Please upload a PDF first.")
        else:
            try:
                chunks, faiss_index = process_pdf(pdf_file)
            except ValueError as err:
                st.error(str(err))
            else:
                st.session_state.chunks = chunks
                st.session_state.faiss_index = faiss_index
                # Chunks retrieved from a previously embedded PDF are stale now.
                st.session_state.relevant_chunks = []
                st.success("PDF Embedded Successfully")

    # Step 2: Generate Translated Queries
    if st.button("2. Query Translation"):
        if not prompt:
            st.warning("Please enter a query first.")
        else:
            st.session_state.translated_queries = translate_query(prompt, query_translation, max_new_tokens, temperature, top_k)
            st.write("**Generated Queries:**", st.session_state.translated_queries)

    # Step 3: Retrieve Relevant Chunks
    if st.button("3. Retrieve Documents"):
        if st.session_state.faiss_index is None:
            st.warning("Please embed a PDF first.")
        elif not st.session_state.translated_queries:
            st.warning("Please generate translated queries first.")
        else:
            st.session_state.relevant_chunks = retrieve_chunks(st.session_state.translated_queries, st.session_state.faiss_index, st.session_state.chunks, top_k=k_value)
            st.write("**Retrieved Chunks:**", st.session_state.relevant_chunks)

    # Step 4: Generate Final Response
    if st.button("4. Generate Final Response"):
        if not st.session_state.relevant_chunks:
            st.warning("Please retrieve documents first.")
        elif not prompt:
            st.warning("Please enter a query first.")
        else:
            context = "\n".join(st.session_state.relevant_chunks)
            final_response = generate_final_response(prompt, context, max_new_tokens, temperature, top_k)
            st.subheader("Final Response:")
            st.write(final_response)
# Build the UI when this file is executed as the main script.
if __name__ == "__main__":
    main()