import streamlit as st
import faiss
import numpy as np
import pickle
import json
from sentence_transformers import SentenceTransformer
from transformers import (
    pipeline,
    RagTokenizer,
    RagRetriever,
    RagSequenceForGeneration,
)
import torch

# ========================
# File Names & Model Names
# ========================
INDEX_FILE = "faiss_index.index"
CHUNKS_FILE = "chunks.pkl"
CURATED_QA_FILE = "curated_qa_pairs.json"
EMBEDDING_MODEL_NAME = "all-MiniLM-L6-v2"
QA_MODEL_NAME = "deepset/roberta-large-squad2"  # For the standard QA pipeline


# ========================
# Loading Functions (cached)
# ========================
@st.cache_resource
def load_index_and_chunks():
    try:
        index = faiss.read_index(INDEX_FILE)
        with open(CHUNKS_FILE, "rb") as f:
            chunks = pickle.load(f)
        return index, chunks
    except Exception as e:
        st.error(f"Error loading FAISS index and chunks: {e}")
        return None, None


@st.cache_resource
def load_embedding_model():
    try:
        model = SentenceTransformer(EMBEDDING_MODEL_NAME)
        return model
    except Exception as e:
        st.error(f"Error loading embedding model: {e}")
        return None


@st.cache_resource
def load_qa_pipeline():
    try:
        qa_pipe = pipeline("question-answering", model=QA_MODEL_NAME, tokenizer=QA_MODEL_NAME)
        return qa_pipe
    except Exception as e:
        st.error(f"Error loading QA pipeline: {e}")
        return None


@st.cache_resource
def load_curated_qa_pairs(json_file=CURATED_QA_FILE):
    try:
        with open(json_file, "r", encoding="utf-8") as f:
            curated_qa_pairs = json.load(f)
        return curated_qa_pairs
    except Exception as e:
        st.error(f"Error loading curated Q/A pairs from JSON: {e}")
        return []


# ========================================
# Standard: Retrieve Curated Q/A Pair Function
# ========================================
def get_curated_pair(query, curated_qa, embed_model, threshold=1.0):
    try:
        curated_questions = [qa["question"] for qa in curated_qa]
        query_embedding = embed_model.encode([query]).astype("float32")
        curated_embeddings = embed_model.encode(curated_questions, show_progress_bar=False)
        curated_embeddings = np.array(curated_embeddings).astype("float32")
        # Build a temporary FAISS index for the curated questions
        dimension = curated_embeddings.shape[1]
        curated_index = faiss.IndexFlatL2(dimension)
        curated_index.add(curated_embeddings)
        k = 1
        distances, indices = curated_index.search(query_embedding, k)
        if distances[0][0] < threshold:
            idx = indices[0][0]
            return curated_qa[idx]
    except Exception as e:
        st.error(f"Error retrieving curated Q/A pair: {e}")
    return None


# ============================================================
# Custom RAG Retriever: Uses your FAISS index & PDF passages
# ============================================================
class CustomRagRetriever(RagRetriever):
    """
    A custom retriever that uses your FAISS index and passages.
    It encodes the query with the provided embedding model, searches your
    FAISS index, and returns the top retrieved documents.
    """

    def __init__(self, config, faiss_index, passages, embed_model, tokenizer, n_docs=5):
        self.faiss_index = faiss_index  # Your custom FAISS index of PDF embeddings
        self.passages = passages        # List of PDF passage texts
        self.embed_model = embed_model  # Embedding model used for encoding queries
        self.n_docs = n_docs            # Number of top documents to retrieve
        self.tokenizer = tokenizer      # Save tokenizer for internal use if needed
        # Override init_retrieval to bypass loading default passages.
        self.init_retrieval = lambda: None
        # Call the parent constructor with the required arguments.
        super().__init__(config, question_encoder_tokenizer=tokenizer, generator_tokenizer=tokenizer)

    def retrieve(self, query, n_docs=None):
        try:
            if n_docs is None:
                n_docs = self.n_docs
            # Encode the query using the embedding model
            query_embedding = self.embed_model.encode([query]).astype("float32")
            distances, indices = self.faiss_index.search(query_embedding, n_docs)
            retrieved_docs = [self.passages[i] for i in indices[0]]
            return {
                "doc_ids": indices,
                "doc_scores": distances,
                "retrieved_docs": retrieved_docs,
            }
        except Exception as e:
            st.error(f"Error in custom retrieval: {e}")
            return {"doc_ids": None, "doc_scores": None, "retrieved_docs": []}


# ============================================================
# Load RAG Model with Custom Retriever (cached for performance)
# ============================================================
@st.cache_resource
def load_rag_model(_faiss_index, passages, _embed_model):
    try:
        tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
        rag_model = RagSequenceForGeneration.from_pretrained("facebook/rag-token-nq")
        custom_retriever = CustomRagRetriever(
            config=rag_model.config,
            faiss_index=_faiss_index,
            passages=passages,
            embed_model=_embed_model,
            tokenizer=tokenizer,
            n_docs=5,
        )
        rag_model.set_retriever(custom_retriever)
        return tokenizer, rag_model
    except Exception as e:
        st.error(f"Error loading RAG model with custom retriever: {e}")
        return None, None


def generate_rag_answer(query, tokenizer, rag_model):
    try:
        inputs = tokenizer(query, return_tensors="pt")
        with torch.no_grad():
            generated_ids = rag_model.generate(**inputs)
        answer = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
        return answer
    except Exception as e:
        st.error(f"Error generating answer with RAG model: {e}")
        return ""


# ========================================
# Main Streamlit App
# ========================================
def main():
    st.title("Takalama")

    if "conversation_history" not in st.session_state:
        st.session_state.conversation_history = ""

    with st.spinner("Loading index and passages..."):
        index, chunks = load_index_and_chunks()
    if index is None or chunks is None:
        return

    with st.spinner("Loading embedding model..."):
        embed_model = load_embedding_model()
    if embed_model is None:
        return

    with st.spinner("Loading QA pipeline..."):
        qa_pipeline = load_qa_pipeline()
    if qa_pipeline is None:
        return

    with st.spinner("Loading curated Q/A pairs..."):
        curated_qa_pairs = load_curated_qa_pairs()

    st.write("Enter your question about the PDF document:")
    query = st.text_input("Question:")

    if query:
        st.session_state.conversation_history += f"User: {query}\n"

        with st.spinner("Retrieving relevant PDF context..."):
            try:
                query_embedding = embed_model.encode([query]).astype("float32")
                k = 3  # Retrieve top 3 chunks
                distances, indices = index.search(query_embedding, k)
                pdf_context = ""
                for idx in indices[0]:
                    pdf_context += chunks[idx] + "\n"
            except Exception as e:
                st.error(f"Error retrieving PDF context: {e}")
                return

        base_context = st.session_state.conversation_history + "\n"

        # --- Option 1: Use RAG Model with Custom Retriever ---
        if st.button("Use RAG Model with Custom Retriever"):
            with st.spinner("Generating answer using RAG model..."):
                tokenizer_rag, rag_model = load_rag_model(index, chunks, embed_model)
                if tokenizer_rag is None or rag_model is None:
                    return
                rag_answer = generate_rag_answer(query, tokenizer_rag, rag_model)
            st.write("**RAG Model Answer:**")
            st.write(rag_answer)
            st.session_state.conversation_history += f"AI (RAG): {rag_answer}\n"
            return

        # --- Option 2: Use Standard QA Pipeline with Curated Q/A Pairs ---
        with st.spinner("Checking for curated Q/A pair..."):
            curated_pair = get_curated_pair(query, curated_qa_pairs, embed_model)

        if curated_pair:
            st.info("A curated Q/A pair was found and will be used for the answer by default.")
            use_full_data = st.checkbox("High Reasoning", value=False)
            if not use_full_data:
                answer = curated_pair["answer"]
                st.write(answer)
                st.session_state.conversation_history += f"AI: {answer}\n"
                return
            else:
                context_to_use = base_context + pdf_context
        else:
            context_to_use = base_context + pdf_context

        with st.expander("Show Full PDF Context"):
            st.write(pdf_context)

        st.subheader("Answer:")
        with st.spinner("Generating answer using standard QA pipeline..."):
            try:
                result = qa_pipeline(question=query, context=context_to_use)
                answer = result["answer"]
                st.write(answer)
                st.session_state.conversation_history += f"AI: {answer}\n"
            except Exception as e:
                st.error(f"Error generating answer using QA pipeline: {e}")


if __name__ == "__main__":
    main()
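
# ------------------------------------------------------------------
# Data file assumptions (illustrative notes, not part of the app logic):
# get_curated_pair() expects CURATED_QA_FILE to hold a JSON list of
# objects with "question" and "answer" keys, e.g. (hypothetical content):
#
#   [
#     {"question": "What is the document about?", "answer": "..."},
#     {"question": "Who wrote the report?", "answer": "..."}
#   ]
#
# INDEX_FILE and CHUNKS_FILE are assumed to come from a separate
# ingestion step that splits the PDF into text chunks, embeds them with
# the same all-MiniLM-L6-v2 model, and saves the FAISS index alongside
# the pickled list of chunk strings, so that chunk order matches the
# index row order used in main() and CustomRagRetriever.retrieve().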