import streamlit as st
import os
import requests
import faiss
import numpy as np
from pdfminer.high_level import extract_text
from sentence_transformers import SentenceTransformer
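# To try the app outside the Space (a rough sketch, assuming the faiss-cpu build):
#   pip install streamlit requests faiss-cpu numpy pdfminer.six sentence-transformers
#   streamlit run app.py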
# Hugging Face Inference API token, read from the environment variable "Key2"
HUGGINGFACE_TOKEN = os.environ.get("Key2")
HF_MODEL = "HuggingFaceH4/zephyr-7b-alpha"
# Load the sentence-embedding model
EMBEDDER = SentenceTransformer("all-MiniLM-L6-v2")
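# all-MiniLM-L6-v2 encodes each text into a 384-dimensional vector; this also
# fixes the dimensionality of the FAISS index built in build_index() below.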
# Default system prompts for each query-translation strategy
SYSTEM_PROMPTS = {
    "Multi-Query": "Generate five alternative versions of the user question: {question}",
    "RAG Fusion": "Combine multiple queries into a single, refined query: {question}",
    "Decomposition": "Break down the user question into simpler sub-questions: {question}",
    "Step Back": "Refine the user question by asking a more general question: {question}",
    "HyDE": "Generate a hypothetical document relevant to the user question: {question}",
}
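# Each strategy rewrites the user question before retrieval: Multi-Query and
# Decomposition fan it out into several queries, RAG Fusion and Step Back
# produce a single refined or more general query, and HyDE generates a
# hypothetical answer document whose embedding is searched instead.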
# Helper function to call the Hugging Face Inference API
def query_hf(prompt, max_new_tokens=1000, temperature=0.7, top_k=50):
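    """Send `prompt` to the hosted model and return the generated text.

    Returns an empty string on an API error so callers can safely
    post-process the result (e.g. split it into lines).
    """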
    headers = {"Authorization": f"Bearer {HUGGINGFACE_TOKEN}"}
    payload = {
        "inputs": prompt,
        "parameters": {
            "max_new_tokens": max_new_tokens,
            "temperature": temperature,
            "top_k": top_k,
            "return_full_text": False,  # return only the completion, not the echoed prompt
        },
    }
    response = requests.post(
        f"https://api-inference.huggingface.co/models/{HF_MODEL}",
        headers=headers,
        json=payload,
    )
    if response.status_code == 200:
        return response.json()[0]["generated_text"]
    st.error(f"Error: {response.status_code} - {response.text}")
    return ""
# Extract text from a PDF
def extract_pdf_text(pdf_file):
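    # pdfminer's extract_text accepts file-like objects, so the Streamlit
    # upload can be passed in directly without writing it to disk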
    return extract_text(pdf_file).split("\n")
# Chunk text into fixed-size word segments
def chunk_text(text_lines, chunk_size=500):
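    # Re-join the extracted lines and split on whitespace, so chunk size is
    # measured in words (500 by default) rather than PDF lines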
    words = " ".join(text_lines).split()
    return [" ".join(words[i:i + chunk_size]) for i in range(0, len(words), chunk_size)]
# Build FAISS index
def build_index(embeddings):
    dimension = embeddings.shape[1]
    index = faiss.IndexFlatL2(dimension)
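    # IndexFlatL2 performs exact (brute-force) L2 nearest-neighbour search and
    # expects a contiguous float32 array of shape (n_chunks, dimension)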
    index.add(embeddings)
    return index
# Search FAISS index
def search_index(query_embedding, index, top_k=5):
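    # index.search returns (distances, indices) for the top_k nearest vectors;
    # only the indices of the single query row are needed here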
    distances, indices = index.search(query_embedding, top_k)
    return indices[0]
# Embed PDF content and build FAISS index
def process_pdf(pdf_file):
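    # Pipeline: PDF text -> word chunks -> MiniLM embeddings -> FAISS index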
    text_lines = extract_pdf_text(pdf_file)
    chunks = chunk_text(text_lines)
    embeddings = EMBEDDER.encode(chunks, convert_to_tensor=False)
    faiss_index = build_index(np.array(embeddings, dtype="float32"))  # FAISS requires float32
    return chunks, faiss_index
# Generate query translations
def translate_query(prompt, method, max_new_tokens, temperature, top_k):
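    # Fill the chosen system prompt with the user question, then treat each
    # line of the model output as one candidate query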
    formatted_prompt = SYSTEM_PROMPTS[method].format(question=prompt)
    return query_hf(formatted_prompt, max_new_tokens, temperature, top_k).split("\n")
# Retrieve relevant chunks from FAISS index
def retrieve_chunks(translated_queries, faiss_index, chunks, top_k=5):
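    # Each translated query is embedded and searched independently, and the
    # hits are pooled into a single context list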
    relevant_chunks = []
    for query in translated_queries:
        query_embedding = EMBEDDER.encode([query], convert_to_tensor=False)
        indices = search_index(np.array(query_embedding, dtype="float32"), faiss_index, top_k)
        # FAISS pads with -1 when the index holds fewer than top_k vectors;
        # skip those and drop chunks already retrieved by an earlier query
        for i in indices:
            if i != -1 and chunks[i] not in relevant_chunks:
                relevant_chunks.append(chunks[i])
    return relevant_chunks
# Generate final response using RAG approach
def generate_final_response(prompt, context, max_new_tokens, temperature, top_k):
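    # Plain prompt stuffing: the retrieved chunks are concatenated verbatim
    # into the context, with no re-ranking or summarisation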
    input_text = f"Context: {context}\n\nAnswer this question: {prompt}"
    return query_hf(input_text, max_new_tokens, temperature, top_k)
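# Minimal non-UI usage sketch (the file name "paper.pdf" and the example
# question are illustrative only):
#   with open("paper.pdf", "rb") as f:
#       chunks, index = process_pdf(f)
#   queries = translate_query("What is the main result?", "Multi-Query", 200, 0.7, 50)
#   context = "\n".join(retrieve_chunks(queries, index, chunks))
#   print(generate_final_response("What is the main result?", context, 500, 0.7, 50))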
# Streamlit UI
def main():
    st.title("Enhanced RAG Model with FAISS Indexing")
    # Sidebar inputs
    pdf_file = st.sidebar.file_uploader("Upload PDF", type="pdf")
    query_translation = st.sidebar.selectbox("Query Translation Method", list(SYSTEM_PROMPTS.keys()))
    # Note: the index built above always uses exact L2 search, so this choice
    # currently only determines whether the K slider is shown
    similarity_method = st.sidebar.selectbox("Similarity Search Method", ["Cosine Similarity", "KNN"])
    k_value = st.sidebar.slider("K Value (for KNN)", 1, 10, 5) if similarity_method == "KNN" else 5
    max_new_tokens = st.sidebar.slider("Max New Tokens", 10, 1000, 500)
    temperature = st.sidebar.slider("Temperature", 0.1, 1.0, 0.7)
    top_k = st.sidebar.slider("Top K", 1, 100, 50)
    # Input prompt
    prompt = st.text_input("Enter your query:")
    # State management: initialise every key the step buttons read, so that
    # pressing a button out of order cannot raise a missing-key error
    if 'chunks' not in st.session_state:
        st.session_state.chunks = []
    if 'faiss_index' not in st.session_state:
        st.session_state.faiss_index = None
    if 'translated_queries' not in st.session_state:
        st.session_state.translated_queries = []
    if 'relevant_chunks' not in st.session_state:
        st.session_state.relevant_chunks = []
    # Step 1: Process PDF
    if st.button("1. Embed PDF") and pdf_file:
        st.session_state.chunks, st.session_state.faiss_index = process_pdf(pdf_file)
        st.success("PDF Embedded Successfully")
    # Step 2: Generate translated queries
    if st.button("2. Query Translation") and prompt:
        st.session_state.translated_queries = translate_query(prompt, query_translation, max_new_tokens, temperature, top_k)
        st.write("**Generated Queries:**", st.session_state.translated_queries)
    # Step 3: Retrieve relevant chunks (requires steps 1 and 2)
    if st.button("3. Retrieve Documents") and st.session_state.translated_queries:
        if st.session_state.faiss_index is None:
            st.error("Embed a PDF first (step 1).")
        else:
            st.session_state.relevant_chunks = retrieve_chunks(st.session_state.translated_queries, st.session_state.faiss_index, st.session_state.chunks, top_k=k_value)
            st.write("**Retrieved Chunks:**", st.session_state.relevant_chunks)
    # Step 4: Generate final response
    if st.button("4. Generate Final Response") and st.session_state.relevant_chunks:
        context = "\n".join(st.session_state.relevant_chunks)
        final_response = generate_final_response(prompt, context, max_new_tokens, temperature, top_k)
        st.subheader("Final Response:")
        st.write(final_response)
if __name__ == "__main__":
    main()