"""Takalama - Streamlit chat UI for question answering over an indexed document.

Curated Q/A pairs are checked first; otherwise the query is answered by an
extractive QA model over the top FAISS-retrieved chunks.
"""

import json
import pickle

import faiss
import streamlit as st
from sentence_transformers import SentenceTransformer
from transformers import pipeline

# ========================
# File Names & Model Names
# ========================
INDEX_FILE = "faiss_index.index"
CHUNKS_FILE = "chunks.pkl"
CURATED_QA_FILE = "curated_qa_pairs.json"

EMBEDDING_MODEL_NAME = "all-MiniLM-L6-v2"
QA_MODEL_NAME = "deepset/roberta-large-squad2"

# ========================
# Loading Functions (cached)
# ========================
@st.cache_resource
def load_index_and_chunks():
    """Loads the FAISS index and the pickled list of text chunks it was built from."""
    try:
        index = faiss.read_index(INDEX_FILE)
        with open(CHUNKS_FILE, "rb") as f:
            chunks = pickle.load(f)
        return index, chunks
    except Exception as e:
        st.error(f"Error loading FAISS index and chunks: {e}")
        return None, None


@st.cache_resource
def load_embedding_model():
    """Loads the sentence-transformer model used to embed user queries."""
    return SentenceTransformer(EMBEDDING_MODEL_NAME)


@st.cache_resource
def load_qa_pipeline():
    """Loads the extractive question-answering pipeline."""
    return pipeline("question-answering", model=QA_MODEL_NAME, tokenizer=QA_MODEL_NAME)


@st.cache_resource
def load_curated_qa_pairs():
    """Loads hand-written question/answer pairs that short-circuit retrieval."""
    try:
        with open(CURATED_QA_FILE, "r", encoding="utf-8") as f:
            return json.load(f)
    except Exception as e:
        st.error(f"Error loading curated Q/A pairs: {e}")
        return []


# ========================================
# Chatbot Interface & Conversation Handling
# ========================================
def display_conversation():
    """Displays conversation history in a structured chat format."""
    for role, message in st.session_state.conversation_history:
        with st.chat_message(role):
            st.write(message)


def add_to_conversation(role, message):
    """Adds a message to conversation history."""
    st.session_state.conversation_history.append((role, message))


# Initialize conversation history
if "conversation_history" not in st.session_state:
    st.session_state.conversation_history = []

# ========================================
# Main Streamlit Chat UI
# ========================================
st.title("Takalama - AI Chat")

# Load models & data
index, chunks = load_index_and_chunks()
embed_model = load_embedding_model()
qa_pipeline = load_qa_pipeline()
curated_qa_pairs = load_curated_qa_pairs()

if index is None or chunks is None:
    st.stop()  # loading failed; the error message has already been displayed above

display_conversation()

# User input
user_query = st.chat_input("Ask a question about the document...")
if user_query:
    add_to_conversation("user", user_query)

    # 1) Check for a matching curated Q/A pair first
    answer = None
    for pair in curated_qa_pairs:
        if user_query.lower() in pair["question"].lower():
            answer = pair["answer"]
            break

    if not answer:
        # 2) Retrieve the most relevant chunks from the FAISS index
        query_embedding = embed_model.encode([user_query]).astype("float32")
        distances, indices = index.search(query_embedding, 3)
        # FAISS returns -1 for missing neighbours, so keep only valid indices
        pdf_context = "\n".join(chunks[idx] for idx in indices[0] if 0 <= idx < len(chunks))

        # 3) Generate an answer with the extractive QA pipeline
        response = qa_pipeline(question=user_query, context=pdf_context)
        answer = response.get("answer") or "I couldn't find an answer to that."

    add_to_conversation("assistant", answer)
    st.rerun()
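
# ----------------------------------------------------------------------
# Note on expected inputs: the curated-answer lookup above assumes that
# curated_qa_pairs.json is a list of objects with "question" and "answer"
# keys. A minimal, hypothetical sample of that file (the entries below are
# illustrative only, not real project data):
#
# [
#   {"question": "What is this document about?",
#    "answer": "It describes the Takalama project."},
#   {"question": "How was the index built?",
#    "answer": "A separate ingestion step wrote faiss_index.index and chunks.pkl."}
# ]
#
# faiss_index.index and chunks.pkl are likewise expected to be produced by a
# separate ingestion script; this app only reads them. The UI is launched with:
#   streamlit run app.py
# (app.py is an assumed file name; substitute whatever this script is saved as).
# ----------------------------------------------------------------------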