# RagLMM / app.py
import streamlit as st
import faiss
import numpy as np
import pickle
import json
from sentence_transformers import SentenceTransformer
from transformers import (
    pipeline,
    RagTokenizer,
    RagRetriever,
    RagSequenceForGeneration,
)
import torch
# ========================
# File Names & Model Names
# ========================
INDEX_FILE = "faiss_index.index"
CHUNKS_FILE = "chunks.pkl"
CURATED_QA_FILE = "curated_qa_pairs.json"
EMBEDDING_MODEL_NAME = "all-MiniLM-L6-v2"
QA_MODEL_NAME = "deepset/roberta-large-squad2" # For the standard QA pipeline
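# NOTE (assumption): faiss_index.index and chunks.pkl are expected to be built
# offline from the source PDF with the same embedding model named above, so that
# row i of the FAISS index corresponds to chunks[i].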
# ========================
# Loading Functions (cached)
# ========================
@st.cache_resource
def load_index_and_chunks():
    try:
        index = faiss.read_index(INDEX_FILE)
        with open(CHUNKS_FILE, "rb") as f:
            chunks = pickle.load(f)
        return index, chunks
    except Exception as e:
        st.error(f"Error loading FAISS index and chunks: {e}")
        return None, None
@st.cache_resource
def load_embedding_model():
    try:
        model = SentenceTransformer(EMBEDDING_MODEL_NAME)
        return model
    except Exception as e:
        st.error(f"Error loading embedding model: {e}")
        return None
@st.cache_resource
def load_qa_pipeline():
    try:
        qa_pipe = pipeline("question-answering", model=QA_MODEL_NAME, tokenizer=QA_MODEL_NAME)
        return qa_pipe
    except Exception as e:
        st.error(f"Error loading QA pipeline: {e}")
        return None
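# The curated Q/A file is assumed to be a JSON list of objects of the form
# [{"question": "...", "answer": "..."}, ...]; those two keys are the only ones
# read below (in get_curated_pair and main).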
@st.cache_resource
def load_curated_qa_pairs(json_file=CURATED_QA_FILE):
    try:
        with open(json_file, "r", encoding="utf-8") as f:
            curated_qa_pairs = json.load(f)
        return curated_qa_pairs
    except Exception as e:
        st.error(f"Error loading curated Q/A pairs from JSON: {e}")
        return []
# ========================================
# Standard: Retrieve Curated Q/A Pair Function
# ========================================
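# Note: faiss.IndexFlatL2 reports squared L2 distances, so `threshold` below is
# compared against a squared distance between the query and question embeddings.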
def get_curated_pair(query, curated_qa, embed_model, threshold=1.0):
    try:
        curated_questions = [qa["question"] for qa in curated_qa]
        query_embedding = embed_model.encode([query]).astype("float32")
        curated_embeddings = embed_model.encode(curated_questions, show_progress_bar=False)
        curated_embeddings = np.array(curated_embeddings).astype("float32")
        # Build a temporary FAISS index for the curated questions
        dimension = curated_embeddings.shape[1]
        curated_index = faiss.IndexFlatL2(dimension)
        curated_index.add(curated_embeddings)
        k = 1
        distances, indices = curated_index.search(query_embedding, k)
        if distances[0][0] < threshold:
            idx = indices[0][0]
            return curated_qa[idx]
    except Exception as e:
        st.error(f"Error retrieving curated Q/A pair: {e}")
    return None
# ============================================================
# Custom RAG Retriever: Uses your FAISS index & PDF passages
# ============================================================
class CustomRagRetriever(RagRetriever):
    """
    A custom retriever that uses your FAISS index and passages.
    It encodes the query with the provided embedding model,
    searches your FAISS index, and returns the top retrieved documents.
    """
    def __init__(self, config, faiss_index, passages, embed_model, tokenizer, n_docs=5):
        self.faiss_index = faiss_index  # Your custom FAISS index of PDF embeddings
        self.passages = passages        # List of PDF passage texts
        self.embed_model = embed_model  # Embedding model used for encoding queries
        self.n_docs = n_docs            # Number of top documents to retrieve
        self.tokenizer = tokenizer      # Save tokenizer for internal use if needed
        # Override init_retrieval to bypass loading default passages.
        self.init_retrieval = lambda: None
        # Call the parent constructor with the required arguments.
        super().__init__(config, question_encoder_tokenizer=tokenizer, generator_tokenizer=tokenizer)
    def retrieve(self, query, n_docs=None):
        try:
            if n_docs is None:
                n_docs = self.n_docs
            # Encode the query using the embedding model
            query_embedding = self.embed_model.encode([query]).astype("float32")
            distances, indices = self.faiss_index.search(query_embedding, n_docs)
            retrieved_docs = [self.passages[i] for i in indices[0]]
            return {
                "doc_ids": indices,
                "doc_scores": distances,
                "retrieved_docs": retrieved_docs,
            }
        except Exception as e:
            st.error(f"Error in custom retrieval: {e}")
            return {"doc_ids": None, "doc_scores": None, "retrieved_docs": []}
# ============================================================
# Load RAG Model with Custom Retriever (cached for performance)
# ============================================================
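# The leading underscores on _faiss_index and _embed_model tell st.cache_resource
# not to hash those (unhashable) arguments when caching the loaded model.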
@st.cache_resource
def load_rag_model(_faiss_index, passages, _embed_model):
    try:
        tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
        rag_model = RagSequenceForGeneration.from_pretrained("facebook/rag-token-nq")
        custom_retriever = CustomRagRetriever(
            config=rag_model.config,
            faiss_index=_faiss_index,
            passages=passages,
            embed_model=_embed_model,
            tokenizer=tokenizer,
            n_docs=5,
        )
        rag_model.set_retriever(custom_retriever)
        return tokenizer, rag_model
    except Exception as e:
        st.error(f"Error loading RAG model with custom retriever: {e}")
        return None, None
def generate_rag_answer(query, tokenizer, rag_model):
    try:
        inputs = tokenizer(query, return_tensors="pt")
        with torch.no_grad():
            generated_ids = rag_model.generate(**inputs)
        answer = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
        return answer
    except Exception as e:
        st.error(f"Error generating answer with RAG model: {e}")
        return ""
# ========================================
# Main Streamlit App
# ========================================
def main():
    st.title("Takalama")
    if "conversation_history" not in st.session_state:
        st.session_state.conversation_history = ""
    with st.spinner("Loading index and passages..."):
        index, chunks = load_index_and_chunks()
    if index is None or chunks is None:
        return
    with st.spinner("Loading embedding model..."):
        embed_model = load_embedding_model()
    if embed_model is None:
        return
    with st.spinner("Loading QA pipeline..."):
        qa_pipeline = load_qa_pipeline()
    if qa_pipeline is None:
        return
    with st.spinner("Loading curated Q/A pairs..."):
        curated_qa_pairs = load_curated_qa_pairs()
    st.write("Enter your question about the PDF document:")
    query = st.text_input("Question:")
    if query:
        st.session_state.conversation_history += f"User: {query}\n"
        with st.spinner("Retrieving relevant PDF context..."):
            try:
                query_embedding = embed_model.encode([query]).astype("float32")
                k = 3  # Retrieve top 3 chunks
                distances, indices = index.search(query_embedding, k)
                pdf_context = ""
                for idx in indices[0]:
                    pdf_context += chunks[idx] + "\n"
            except Exception as e:
                st.error(f"Error retrieving PDF context: {e}")
                return
        base_context = st.session_state.conversation_history + "\n"
        # --- Option 1: Use RAG Model with Custom Retriever ---
        if st.button("Use RAG Model with Custom Retriever"):
            with st.spinner("Generating answer using RAG model..."):
                tokenizer_rag, rag_model = load_rag_model(index, chunks, embed_model)
                if tokenizer_rag is None or rag_model is None:
                    return
                rag_answer = generate_rag_answer(query, tokenizer_rag, rag_model)
                st.write("**RAG Model Answer:**")
                st.write(rag_answer)
                st.session_state.conversation_history += f"AI (RAG): {rag_answer}\n"
            return
        # --- Option 2: Use Standard QA Pipeline with Curated Q/A Pairs ---
        with st.spinner("Checking for curated Q/A pair..."):
            curated_pair = get_curated_pair(query, curated_qa_pairs, embed_model)
        if curated_pair:
            st.info("A curated Q/A pair was found and will be used for the answer by default.")
            use_full_data = st.checkbox("High Reasoning", value=False)
            if not use_full_data:
                answer = curated_pair["answer"]
                st.write(answer)
                st.session_state.conversation_history += f"AI: {answer}\n"
                return
            else:
                context_to_use = base_context + pdf_context
        else:
            context_to_use = base_context + pdf_context
        with st.expander("Show Full PDF Context"):
            st.write(pdf_context)
        st.subheader("Answer:")
        with st.spinner("Generating answer using standard QA pipeline..."):
            try:
                result = qa_pipeline(question=query, context=context_to_use)
                answer = result["answer"]
                st.write(answer)
                st.session_state.conversation_history += f"AI: {answer}\n"
            except Exception as e:
                st.error(f"Error generating answer using QA pipeline: {e}")
if __name__ == "__main__":
    main()