import streamlit as st
import faiss
import numpy as np
import pickle
import json
from sentence_transformers import SentenceTransformer
from transformers import pipeline
# ========================
# File Names & Model Names
# ========================
INDEX_FILE = "faiss_index.index"
CHUNKS_FILE = "chunks.pkl"
CURATED_QA_FILE = "curated_qa_pairs.json"
EMBEDDING_MODEL_NAME = "all-MiniLM-L6-v2"
QA_MODEL_NAME = "deepset/roberta-large-squad2"
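# The index and chunk files above are expected to come from a separate
# ingestion step. A minimal sketch of that step (assuming `chunks` is a
# list[str] already extracted from the source document):
#
#     embeddings = SentenceTransformer(EMBEDDING_MODEL_NAME).encode(chunks)
#     index = faiss.IndexFlatL2(embeddings.shape[1])
#     index.add(np.asarray(embeddings, dtype="float32"))
#     faiss.write_index(index, INDEX_FILE)
#     with open(CHUNKS_FILE, "wb") as f:
#         pickle.dump(chunks, f)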
# ========================
# Loading Functions (cached)
# ========================
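# st.cache_resource keeps these heavyweight objects (FAISS index, models)
# alive across Streamlit reruns and sessions, so each is loaded once per
# process instead of on every user interaction.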
@st.cache_resource
def load_index_and_chunks():
    try:
        index = faiss.read_index(INDEX_FILE)
        with open(CHUNKS_FILE, "rb") as f:
            chunks = pickle.load(f)
        return index, chunks
    except Exception as e:
        st.error(f"Error loading FAISS index and chunks: {e}")
        return None, None
@st.cache_resource
def load_embedding_model():
    # all-MiniLM-L6-v2 yields 384-dimensional embeddings; the FAISS index must
    # have been built with the same model for the vectors to be comparable.
    return SentenceTransformer(EMBEDDING_MODEL_NAME)
@st.cache_resource
def load_qa_pipeline():
    # Extractive QA: the model selects an answer span from the supplied
    # context (SQuAD2-style) rather than generating free-form text.
    return pipeline("question-answering", model=QA_MODEL_NAME, tokenizer=QA_MODEL_NAME)
@st.cache_resource
def load_curated_qa_pairs():
    try:
        with open(CURATED_QA_FILE, "r", encoding="utf-8") as f:
            return json.load(f)
    except Exception as e:
        st.error(f"Error loading curated Q/A pairs: {e}")
        return []
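# The curated file is assumed to be a JSON list of objects shaped like
# {"question": "...", "answer": "..."}, matching the lookup in the chat
# handler below.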
# ========================================
# Chatbot Interface & Conversation Handling
# ========================================
def display_conversation():
"""Displays conversation history in a structured chat format."""
for role, message in st.session_state.conversation_history:
with st.chat_message(role):
st.write(message)
def add_to_conversation(role, message):
"""Adds a message to conversation history."""
st.session_state.conversation_history.append((role, message))
# Initialize conversation history
if "conversation_history" not in st.session_state:
    st.session_state.conversation_history = []
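# st.session_state is scoped to the browser session, so the history above
# survives the st.rerun() issued after each answer.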
# ========================================
# Main Streamlit Chat UI
# ========================================
st.title("Takalama - AI Chat")
# Load models & data
index, chunks = load_index_and_chunks()
embed_model = load_embedding_model()
qa_pipeline = load_qa_pipeline()
curated_qa_pairs = load_curated_qa_pairs()
display_conversation()
# User Input
user_query = st.chat_input("Ask a question about the document...")
if user_query:
    add_to_conversation("user", user_query)

    # First look for a curated Q/A pair matching the query.
    answer = None
    for pair in curated_qa_pairs:
        if user_query.lower() in pair["question"].lower():
            answer = pair["answer"]
            break

    if not answer:
        if index is None or chunks is None:
            answer = "The document index could not be loaded."
        else:
            # Retrieve the top-3 most relevant chunks as context for the QA
            # model; FAISS returns -1 for slots it could not fill.
            query_embedding = embed_model.encode([user_query]).astype("float32")
            distances, indices = index.search(query_embedding, 3)
            pdf_context = "\n".join(chunks[idx] for idx in indices[0] if idx >= 0)
            # Generate an answer with the extractive QA pipeline; SQuAD2-style
            # models may return an empty span when they find no answer.
            response = qa_pipeline(question=user_query, context=pdf_context)
            answer = response.get("answer") or "I couldn't find an answer to that."

    add_to_conversation("assistant", answer)
    st.rerun()
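# A fuzzier curated-Q/A lookup (sketch only, not wired in): embed the curated
# questions with the same model and accept the nearest one above a
# cosine-similarity threshold, instead of the substring test used above.
#
#     q_embs = embed_model.encode(
#         [p["question"] for p in curated_qa_pairs], normalize_embeddings=True
#     )
#     u_emb = embed_model.encode([user_query], normalize_embeddings=True)
#     sims = (q_embs @ u_emb.T).ravel()
#     best = int(np.argmax(sims))
#     if sims[best] > 0.8:  # threshold is an assumption to tune
#         answer = curated_qa_pairs[best]["answer"]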