File size: 3,285 Bytes
ae479fd
076c725
 
 
39d1fa0
 
bd4d03b
39d1fa0
1d9b2c2
d8ffd44
 
 
076c725
 
6170969
226641d
076c725
bd4d03b
d8ffd44
 
 
 
076c725
 
411f496
 
 
 
 
 
 
 
076c725
 
 
bd4d03b
076c725
 
 
bd4d03b
076c725
6170969
bd4d03b
411f496
bd4d03b
 
1bfc32c
 
411f496
226641d
d8ffd44
bd4d03b
d8ffd44
bd4d03b
 
1bfc32c
bd4d03b
 
226641d
bd4d03b
39d1fa0
bd4d03b
d8ffd44
bd4d03b
 
 
d8ffd44
 
bd4d03b
d8ffd44
bd4d03b
 
 
 
 
 
 
 
 
 
 
 
 
 
076c725
39d1fa0
bd4d03b
 
 
 
 
39d1fa0
bd4d03b
39d1fa0
bd4d03b
 
 
411f496
bd4d03b
 
 
 
 
1bfc32c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import streamlit as st
import faiss
import numpy as np
import pickle
import json
import torch
from sentence_transformers import SentenceTransformer
from transformers import pipeline, RagTokenizer, RagRetriever, RagSequenceForGeneration

# ========================
# File Names & Model Names
# ========================
INDEX_FILE = "faiss_index.index"
CHUNKS_FILE = "chunks.pkl"
CURATED_QA_FILE = "curated_qa_pairs.json"

EMBEDDING_MODEL_NAME = "all-MiniLM-L6-v2"
QA_MODEL_NAME = "deepset/roberta-large-squad2"

# ========================
# Loading Functions (cached)
# ========================
@st.cache_resource
def load_index_and_chunks():
    """Load the FAISS index and the pickled chunk collection from disk.

    Reads ``INDEX_FILE`` with ``faiss.read_index`` and unpickles
    ``CHUNKS_FILE``.

    Returns:
        A ``(index, chunks)`` tuple on success. If anything raises while
        loading, the error is shown with ``st.error`` and ``(None, None)``
        is returned instead, so callers must be prepared for ``None``.
    """
    try:
        faiss_index = faiss.read_index(INDEX_FILE)
        with open(CHUNKS_FILE, "rb") as chunk_file:
            # NOTE: pickle executes arbitrary code on load; CHUNKS_FILE must
            # come from a trusted source.
            stored_chunks = pickle.load(chunk_file)
    except Exception as e:
        st.error(f"Error loading FAISS index and chunks: {e}")
        return None, None
    else:
        return faiss_index, stored_chunks

@st.cache_resource
def load_embedding_model():
    """Build the sentence-embedding model named by ``EMBEDDING_MODEL_NAME``.

    Returns:
        A ``SentenceTransformer`` instance. The function is wrapped in
        ``st.cache_resource`` so the loaded object can be reused.
    """
    encoder = SentenceTransformer(EMBEDDING_MODEL_NAME)
    return encoder

@st.cache_resource
def load_qa_pipeline():
    """Build the extractive question-answering pipeline.

    ``QA_MODEL_NAME`` is used for both the model and the tokenizer.

    Returns:
        The object returned by ``transformers.pipeline`` for the
        ``"question-answering"`` task.
    """
    qa = pipeline(
        "question-answering",
        model=QA_MODEL_NAME,
        tokenizer=QA_MODEL_NAME,
    )
    return qa

@st.cache_resource
def load_curated_qa_pairs():
    """Read the curated question/answer pairs from ``CURATED_QA_FILE``.

    The file is opened as UTF-8 text and parsed as JSON.

    Returns:
        The parsed JSON content on success. If opening or parsing raises,
        the error is shown with ``st.error`` and an empty list is returned.
    """
    try:
        with open(CURATED_QA_FILE, "r", encoding="utf-8") as qa_file:
            parsed_pairs = json.load(qa_file)
    except Exception as e:
        st.error(f"Error loading curated Q/A pairs: {e}")
        return []
    return parsed_pairs

# ========================================
# Chatbot Interface & Conversation Handling
# ========================================
def display_conversation():
    """Render every stored conversation entry as a chat message.

    Walks ``st.session_state.conversation_history`` in order; each entry is
    a ``(role, message)`` pair written inside ``st.chat_message(role)``.
    """
    history = st.session_state.conversation_history
    for entry in history:
        speaker, text = entry
        with st.chat_message(speaker):
            st.write(text)

def add_to_conversation(role, message):
    """Append one ``(role, message)`` entry to the session's chat history.

    Args:
        role: Speaker label stored with the message (the script uses
            ``"user"`` and ``"assistant"``).
        message: The text to store.
    """
    history = st.session_state.conversation_history
    entry = (role, message)
    history.append(entry)

# Create the per-session chat history the first time the script runs;
# later reruns keep whatever is already stored.
if "conversation_history" not in st.session_state:
    st.session_state["conversation_history"] = []

# ========================================
# Main Streamlit Chat UI
# ========================================
# Page title, rendered before anything else on the page.
st.title("Takalama - AI Chat")

# Load models & data
# Each loader is decorated with st.cache_resource.
# index and chunks are both None when load_index_and_chunks() failed.
index, chunks = load_index_and_chunks()
embed_model = load_embedding_model()
qa_pipeline = load_qa_pipeline()
# Empty list when the curated Q/A file could not be loaded.
curated_qa_pairs = load_curated_qa_pairs()

# Replay the stored history before handling any new input.
display_conversation()

# User Input
# Flow: record the user's message, try the curated Q/A pairs first, fall back
# to retrieval + extractive QA, record the assistant's reply, then rerun so
# the updated history is rendered by display_conversation().
user_query = st.chat_input("Ask a question about the document...")
if user_query:
    add_to_conversation("user", user_query)

    fallback_answer = "I couldn't find an answer to that."

    # Check for curated Q/A pair.
    # NOTE(review): this is a case-insensitive substring test of the query
    # inside each curated question, so a very short query matches the first
    # curated question that contains it — confirm this is intended.
    answer = None
    normalized_query = user_query.lower()
    for pair in curated_qa_pairs:
        if normalized_query in pair["question"].lower():
            answer = pair["answer"]
            break

    if not answer:
        if index is None or chunks is None:
            # load_index_and_chunks() returns (None, None) on failure; answer
            # gracefully instead of crashing on index.search.
            answer = (
                "The document index is unavailable, so I can't search the "
                "document right now."
            )
        else:
            # Retrieve relevant context
            query_embedding = embed_model.encode([user_query]).astype("float32")
            _, indices = index.search(query_embedding, 3)

            # FAISS reports missing neighbours as -1 (e.g. when the index
            # holds fewer than 3 vectors); a negative index would silently
            # pick a chunk from the end of the collection, so skip those.
            retrieved = [chunks[idx] for idx in indices[0] if idx >= 0]

            if retrieved:
                pdf_context = "\n".join(retrieved)

                # Generate an answer using the QA pipeline
                response = qa_pipeline(question=user_query, context=pdf_context)
                # `or` also covers an empty-string answer, which .get() with a
                # default would pass through as a blank chat message.
                answer = response.get("answer") or fallback_answer
            else:
                # Nothing retrieved: skip the QA model rather than calling it
                # with an empty context.
                answer = fallback_answer

    add_to_conversation("assistant", answer)
    st.rerun()