# RagLMM / app.py
import streamlit as st
import faiss
import pickle
import json
from sentence_transformers import SentenceTransformer
from transformers import pipeline
# ========================
# File Names & Model Names
# ========================
INDEX_FILE = "faiss_index.index"
CHUNKS_FILE = "chunks.pkl"
CURATED_QA_FILE = "curated_qa_pairs.json"
EMBEDDING_MODEL_NAME = "all-MiniLM-L6-v2"
QA_MODEL_NAME = "deepset/roberta-large-squad2"
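# NOTE: INDEX_FILE and CHUNKS_FILE are assumed to be produced by a separate
# offline ingestion step that is not part of this app, roughly along these lines
# (the chunking strategy itself is an assumption; this app only reads the files):
#
#     embedder = SentenceTransformer(EMBEDDING_MODEL_NAME)
#     vectors = embedder.encode(chunks).astype("float32")
#     index = faiss.IndexFlatL2(vectors.shape[1])
#     index.add(vectors)
#     faiss.write_index(index, INDEX_FILE)
#     with open(CHUNKS_FILE, "wb") as f:
#         pickle.dump(chunks, f)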
# ========================
# Loading Functions (cached)
# ========================
@st.cache_resource
def load_index_and_chunks():
try:
index = faiss.read_index(INDEX_FILE)
with open(CHUNKS_FILE, "rb") as f:
chunks = pickle.load(f)
return index, chunks
except Exception as e:
st.error(f"Error loading FAISS index and chunks: {e}")
return None, None
@st.cache_resource
def load_embedding_model():
return SentenceTransformer(EMBEDDING_MODEL_NAME)
@st.cache_resource
def load_qa_pipeline():
return pipeline("question-answering", model=QA_MODEL_NAME, tokenizer=QA_MODEL_NAME)
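# deepset/roberta-large-squad2 is an extractive QA model: the pipeline returns a
# span copied from the supplied context (plus a confidence "score"), not free-form text.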
@st.cache_resource
def load_curated_qa_pairs():
try:
with open(CURATED_QA_FILE, "r", encoding="utf-8") as f:
return json.load(f)
except Exception as e:
st.error(f"Error loading curated Q/A pairs: {e}")
return []
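# CURATED_QA_FILE is expected to hold a list of {"question": ..., "answer": ...}
# objects; the chat loop below does a case-insensitive substring match against
# the "question" field before falling back to retrieval.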
# ========================================
# Chatbot Interface & Conversation Handling
# ========================================
def display_conversation():
"""Displays conversation history in a structured chat format."""
for role, message in st.session_state.conversation_history:
with st.chat_message(role):
st.write(message)
def add_to_conversation(role, message):
"""Adds a message to conversation history."""
st.session_state.conversation_history.append((role, message))
# Initialize conversation history
if "conversation_history" not in st.session_state:
st.session_state.conversation_history = []
# ========================================
# Main Streamlit Chat UI
# ========================================
st.title("Takalama - AI Chat")
# Load models & data
index, chunks = load_index_and_chunks()
embed_model = load_embedding_model()
qa_pipeline = load_qa_pipeline()
curated_qa_pairs = load_curated_qa_pairs()
# Stop early if the index or chunks failed to load (the loader already showed an error)
if index is None or chunks is None:
    st.stop()
display_conversation()
# User Input
user_query = st.chat_input("Ask a question about the document...")
if user_query:
add_to_conversation("user", user_query)
# Check for curated Q/A pair
answer = None
for pair in curated_qa_pairs:
if user_query.lower() in pair["question"].lower():
answer = pair["answer"]
break
    if not answer:
        # Retrieve the most relevant document chunks from the FAISS index
        query_embedding = embed_model.encode([user_query]).astype("float32")
        distances, indices = index.search(query_embedding, 3)
        # FAISS pads results with -1 when fewer than k vectors are available
        pdf_context = "\n".join(chunks[idx] for idx in indices[0] if idx != -1)
        # Extract an answer span from the retrieved context with the QA pipeline
        response = qa_pipeline(question=user_query, context=pdf_context)
        answer = response.get("answer") or "I couldn't find an answer to that."
add_to_conversation("assistant", answer)
st.rerun()