import streamlit as st
import faiss
import numpy as np
import pickle
import json
from sentence_transformers import SentenceTransformer
from transformers import (
    pipeline,
    RagTokenizer,
    RagRetriever,
    RagSequenceForGeneration,
)
import torch

# ========================
# File Names & Model Names
# ========================

INDEX_FILE = "faiss_index.index"
CHUNKS_FILE = "chunks.pkl"
CURATED_QA_FILE = "curated_qa_pairs.json"

EMBEDDING_MODEL_NAME = "all-MiniLM-L6-v2"
QA_MODEL_NAME = "deepset/roberta-large-squad2"  # For the standard QA pipeline
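
# The FAISS index and pickled chunks are assumed to be built offline with the same
# embedding model. A minimal sketch (not part of this app; IndexFlatL2 is an
# assumption, consistent with the L2 distances used throughout):
#
#     chunks = ["passage one ...", "passage two ..."]       # text chunks from the PDF
#     embeddings = SentenceTransformer(EMBEDDING_MODEL_NAME).encode(chunks)
#     embeddings = np.asarray(embeddings, dtype="float32")
#     index = faiss.IndexFlatL2(embeddings.shape[1])         # exact L2 search
#     index.add(embeddings)
#     faiss.write_index(index, INDEX_FILE)
#     with open(CHUNKS_FILE, "wb") as f:
#         pickle.dump(chunks, f)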

# ========================
# Loading Functions (cached)
# ========================

@st.cache_resource
def load_index_and_chunks():
    try:
        index = faiss.read_index(INDEX_FILE)
        with open(CHUNKS_FILE, "rb") as f:
            chunks = pickle.load(f)
        return index, chunks
    except Exception as e:
        st.error(f"Error loading FAISS index and chunks: {e}")
        return None, None

@st.cache_resource
def load_embedding_model():
    try:
        model = SentenceTransformer(EMBEDDING_MODEL_NAME)
        return model
    except Exception as e:
        st.error(f"Error loading embedding model: {e}")
        return None

@st.cache_resource
def load_qa_pipeline():
    try:
        qa_pipe = pipeline("question-answering", model=QA_MODEL_NAME, tokenizer=QA_MODEL_NAME)
        return qa_pipe
    except Exception as e:
        st.error(f"Error loading QA pipeline: {e}")
        return None

@st.cache_resource
def load_curated_qa_pairs(json_file=CURATED_QA_FILE):
    try:
        with open(json_file, "r", encoding="utf-8") as f:
            curated_qa_pairs = json.load(f)
        return curated_qa_pairs
    except Exception as e:
        st.error(f"Error loading curated Q/A pairs from JSON: {e}")
        return []
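
# curated_qa_pairs.json is expected to be a JSON list of objects with "question"
# and "answer" keys (inferred from the lookups below), for example:
#
#     [
#         {"question": "What is the refund policy?", "answer": "Refunds are issued ..."},
#         {"question": "Who wrote the report?", "answer": "The report was written by ..."}
#     ]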

# ========================================
# Standard: Retrieve Curated Q/A Pair Function
# ========================================

def get_curated_pair(query, curated_qa, embed_model, threshold=1.0):
    try:
        curated_questions = [qa["question"] for qa in curated_qa]
        query_embedding = embed_model.encode([query]).astype("float32")
        curated_embeddings = embed_model.encode(curated_questions, show_progress_bar=False)
        curated_embeddings = np.array(curated_embeddings).astype("float32")
        
        # Build a temporary FAISS index for the curated questions
        dimension = curated_embeddings.shape[1]
        curated_index = faiss.IndexFlatL2(dimension)
        curated_index.add(curated_embeddings)
        
        k = 1
        distances, indices = curated_index.search(query_embedding, k)
        
        if distances[0][0] < threshold:
            idx = indices[0][0]
            return curated_qa[idx]
    except Exception as e:
        st.error(f"Error retrieving curated Q/A pair: {e}")
    return None
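
# Usage sketch: `threshold` is an L2 distance in the embedding space, so smaller
# values require a closer match to a curated question before a pair is returned.
#
#     pair = get_curated_pair("What does the document cover?", curated_qa_pairs, embed_model)
#     if pair is not None:
#         st.write(pair["answer"])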

# ============================================================
# Custom RAG Retriever: Uses your FAISS index & PDF passages
# ============================================================
class CustomRagRetriever(RagRetriever):
    """
    A custom retriever that uses your FAISS index and passages.
    It encodes the query with the provided embedding model,
    searches your FAISS index, and returns the top retrieved documents.
    """
    def __init__(self, config, faiss_index, passages, embed_model, tokenizer, n_docs=5):
        self.faiss_index = faiss_index      # Your custom FAISS index of PDF embeddings
        self.passages = passages            # List of PDF passage texts
        self.embed_model = embed_model      # Embedding model used for encoding queries
        self.n_docs = n_docs                # Number of top documents to retrieve
        self.tokenizer = tokenizer          # Save tokenizer for internal use if needed
        # Override init_retrieval to bypass loading default passages.
        self.init_retrieval = lambda: None
        # Call the parent constructor with the required arguments.
        super().__init__(config, question_encoder_tokenizer=tokenizer, generator_tokenizer=tokenizer)
    
    def retrieve(self, query, n_docs=None):
        try:
            if n_docs is None:
                n_docs = self.n_docs
            # Encode the query using the embedding model
            query_embedding = self.embed_model.encode([query]).astype("float32")
            distances, indices = self.faiss_index.search(query_embedding, n_docs)
            retrieved_docs = [self.passages[i] for i in indices[0]]
            return {
                "doc_ids": indices,
                "doc_scores": distances,
                "retrieved_docs": retrieved_docs,
            }
        except Exception as e:
            st.error(f"Error in custom retrieval: {e}")
            return {"doc_ids": None, "doc_scores": None, "retrieved_docs": []}

# ============================================================
# Load RAG Model with Custom Retriever (cached for performance)
# ============================================================
@st.cache_resource
def load_rag_model(_faiss_index, passages, _embed_model):
    # Leading underscores tell st.cache_resource to skip hashing these unhashable arguments.
    try:
        tokenizer = RagTokenizer.from_pretrained("facebook/rag-sequence-nq")
        rag_model = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq")
        
        custom_retriever = CustomRagRetriever(
            config=rag_model.config,
            faiss_index=_faiss_index,
            passages=passages,
            embed_model=_embed_model,
            tokenizer=tokenizer,
            n_docs=5
        )
        rag_model.set_retriever(custom_retriever)
        return tokenizer, rag_model
    except Exception as e:
        st.error(f"Error loading RAG model with custom retriever: {e}")
        return None, None

def generate_rag_answer(query, tokenizer, rag_model):
    try:
        inputs = tokenizer(query, return_tensors="pt")
        with torch.no_grad():
            generated_ids = rag_model.generate(**inputs)
        answer = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
        return answer
    except Exception as e:
        st.error(f"Error generating answer with RAG model: {e}")
        return ""

# ========================================
# Main Streamlit App
# ========================================
def main():
    st.title("Takalama")
    
    if "conversation_history" not in st.session_state:
        st.session_state.conversation_history = ""
    
    with st.spinner("Loading index and passages..."):
        index, chunks = load_index_and_chunks()
        if index is None or chunks is None:
            return
    
    with st.spinner("Loading embedding model..."):
        embed_model = load_embedding_model()
        if embed_model is None:
            return
    
    with st.spinner("Loading QA pipeline..."):
        qa_pipeline = load_qa_pipeline()
        if qa_pipeline is None:
            return
    
    with st.spinner("Loading curated Q/A pairs..."):
        curated_qa_pairs = load_curated_qa_pairs()
    
    st.write("Enter your question about the PDF document:")
    query = st.text_input("Question:")
    
    if query:
        st.session_state.conversation_history += f"User: {query}\n"
        
        with st.spinner("Retrieving relevant PDF context..."):
            try:
                query_embedding = embed_model.encode([query]).astype("float32")
                k = 3  # Retrieve top 3 chunks
                distances, indices = index.search(query_embedding, k)
                pdf_context = ""
                for idx in indices[0]:
                    pdf_context += chunks[idx] + "\n"
            except Exception as e:
                st.error(f"Error retrieving PDF context: {e}")
                return
        
        base_context = st.session_state.conversation_history + "\n"
        
        # --- Option 1: Use RAG Model with Custom Retriever ---
        if st.button("Use RAG Model with Custom Retriever"):
            with st.spinner("Generating answer using RAG model..."):
                tokenizer_rag, rag_model = load_rag_model(index, chunks, embed_model)
                if tokenizer_rag is None or rag_model is None:
                    return
                rag_answer = generate_rag_answer(query, tokenizer_rag, rag_model)
                st.write("**RAG Model Answer:**")
                st.write(rag_answer)
                st.session_state.conversation_history += f"AI (RAG): {rag_answer}\n"
                return
        
        # --- Option 2: Use Standard QA Pipeline with Curated Q/A Pairs ---
        with st.spinner("Checking for curated Q/A pair..."):
            curated_pair = get_curated_pair(query, curated_qa_pairs, embed_model)
        
        if curated_pair:
            st.info("A curated Q/A pair was found and will be used for the answer by default.")
            use_full_data = st.checkbox("High Reasoning", value=False)
            if not use_full_data:
                answer = curated_pair["answer"]
                st.write(answer)
                st.session_state.conversation_history += f"AI: {answer}\n"
                return
            else:
                context_to_use = base_context + pdf_context
        else:
            context_to_use = base_context + pdf_context
        
        with st.expander("Show Full PDF Context"):
            st.write(pdf_context)
        
        st.subheader("Answer:")
        with st.spinner("Generating answer using standard QA pipeline..."):
            try:
                result = qa_pipeline(question=query, context=context_to_use)
                answer = result["answer"]
                st.write(answer)
                st.session_state.conversation_history += f"AI: {answer}\n"
            except Exception as e:
                st.error(f"Error generating answer using QA pipeline: {e}")

if __name__ == "__main__":
    main()