"""Takalama - a Streamlit chat app that answers questions about a document,
checking curated Q/A pairs first and falling back to FAISS retrieval with
extractive question answering."""

import streamlit as st
import faiss
import pickle
import json
from sentence_transformers import SentenceTransformer
from transformers import pipeline
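
# Files built ahead of time: the FAISS index, the pickled text chunks, and
# the curated question/answer pairs.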
INDEX_FILE = "faiss_index.index"
CHUNKS_FILE = "chunks.pkl"
CURATED_QA_FILE = "curated_qa_pairs.json"
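
# Model names: a sentence-transformer for query embeddings and an extractive
# QA model that answers from retrieved context.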
EMBEDDING_MODEL_NAME = "all-MiniLM-L6-v2"
QA_MODEL_NAME = "deepset/roberta-large-squad2"


@st.cache_resource
def load_index_and_chunks():
    """Loads the FAISS index and the pickled text chunks from disk."""
    try:
        index = faiss.read_index(INDEX_FILE)
        with open(CHUNKS_FILE, "rb") as f:
            chunks = pickle.load(f)
        return index, chunks
    except Exception as e:
        st.error(f"Error loading FAISS index and chunks: {e}")
        return None, None


@st.cache_resource
def load_embedding_model():
    """Loads the sentence-transformer model used to embed queries."""
    return SentenceTransformer(EMBEDDING_MODEL_NAME)


@st.cache_resource
def load_qa_pipeline():
    """Loads the extractive question-answering pipeline."""
    return pipeline("question-answering", model=QA_MODEL_NAME, tokenizer=QA_MODEL_NAME)


@st.cache_resource
def load_curated_qa_pairs():
    """Loads curated Q/A pairs from JSON; returns an empty list on failure."""
    try:
        with open(CURATED_QA_FILE, "r", encoding="utf-8") as f:
            return json.load(f)
    except Exception as e:
        st.error(f"Error loading curated Q/A pairs: {e}")
        return []


def display_conversation():
    """Displays conversation history in a structured chat format."""
    for role, message in st.session_state.conversation_history:
        with st.chat_message(role):
            st.write(message)


def add_to_conversation(role, message):
    """Adds a message to conversation history."""
    st.session_state.conversation_history.append((role, message))
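
# Keep chat history in session state so it survives Streamlit reruns.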
if "conversation_history" not in st.session_state:
    st.session_state.conversation_history = []

st.title("Takalama - AI Chat")
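
# Load the heavyweight resources once; @st.cache_resource reuses them across
# reruns and sessions.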
index, chunks = load_index_and_chunks()
embed_model = load_embedding_model()
qa_pipeline = load_qa_pipeline()
curated_qa_pairs = load_curated_qa_pairs()

# The retrieval step below dereferences the index and chunks, so stop this
# script run here if they failed to load.
if index is None or chunks is None:
    st.stop()

display_conversation()

user_query = st.chat_input("Ask a question about the document...")
if user_query:
    add_to_conversation("user", user_query)
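
    # Try the curated Q/A pairs first: a case-insensitive substring match on
    # a stored question answers the query without touching the index.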
    answer = None
    for pair in curated_qa_pairs:
        if user_query.lower() in pair["question"].lower():
            answer = pair["answer"]
            break

    if not answer:
        # Fall back to retrieval: embed the query, fetch the 3 nearest chunks
        # from the FAISS index, and join them into one context string.
        query_embedding = embed_model.encode([user_query]).astype("float32")
        distances, indices = index.search(query_embedding, 3)
        # FAISS pads missing neighbors with -1, so skip those slots.
        pdf_context = "\n".join(chunks[idx] for idx in indices[0] if idx != -1)

        # Extractive QA returns a dict with the answer span and a score.
        response = qa_pipeline(question=user_query, context=pdf_context)
        answer = response.get("answer", "I couldn't find an answer to that.")

    add_to_conversation("assistant", answer)
    # Rerun so the full history, including the new turn, is redrawn.
    st.rerun()