Tamil Eniyan committed on
Commit bd4d03b · 1 Parent(s): a42855b

Updated application with optimizations

Files changed (1)
  1. app.py +54 -205
app.py CHANGED
@@ -3,30 +3,23 @@ import faiss
 import numpy as np
 import pickle
 import json
-from sentence_transformers import SentenceTransformer
-from transformers import (
-    pipeline,
-    RagTokenizer,
-    RagRetriever,
-    RagSequenceForGeneration,
-)
 import torch
+from sentence_transformers import SentenceTransformer
+from transformers import pipeline, RagTokenizer, RagRetriever, RagSequenceForGeneration

 # ========================
 # File Names & Model Names
 # ========================
-
 INDEX_FILE = "faiss_index.index"
 CHUNKS_FILE = "chunks.pkl"
 CURATED_QA_FILE = "curated_qa_pairs.json"

 EMBEDDING_MODEL_NAME = "all-MiniLM-L6-v2"
-QA_MODEL_NAME = "deepset/roberta-large-squad2"  # For the standard QA pipeline
+QA_MODEL_NAME = "deepset/roberta-large-squad2"

 # ========================
 # Loading Functions (cached)
 # ========================
-
 @st.cache_resource
 def load_index_and_chunks():
     try:
@@ -40,217 +33,73 @@ def load_index_and_chunks():

 @st.cache_resource
 def load_embedding_model():
-    try:
-        model = SentenceTransformer(EMBEDDING_MODEL_NAME)
-        return model
-    except Exception as e:
-        st.error(f"Error loading embedding model: {e}")
-        return None
+    return SentenceTransformer(EMBEDDING_MODEL_NAME)

 @st.cache_resource
 def load_qa_pipeline():
-    try:
-        qa_pipe = pipeline("question-answering", model=QA_MODEL_NAME, tokenizer=QA_MODEL_NAME)
-        return qa_pipe
-    except Exception as e:
-        st.error(f"Error loading QA pipeline: {e}")
-        return None
+    return pipeline("question-answering", model=QA_MODEL_NAME, tokenizer=QA_MODEL_NAME)

 @st.cache_resource
-def load_curated_qa_pairs(json_file=CURATED_QA_FILE):
+def load_curated_qa_pairs():
     try:
-        with open(json_file, "r", encoding="utf-8") as f:
-            curated_qa_pairs = json.load(f)
-        return curated_qa_pairs
-    except Exception as e:
-        st.error(f"Error loading curated Q/A pairs from JSON: {e}")
+        with open(CURATED_QA_FILE, "r", encoding="utf-8") as f:
+            return json.load(f)
+    except:
         return []

 # ========================================
-# Standard: Retrieve Curated Q/A Pair Function
+# Chatbot Interface & Conversation Handling
 # ========================================

-def get_curated_pair(query, curated_qa, embed_model, threshold=1.0):
-    try:
-        curated_questions = [qa["question"] for qa in curated_qa]
-        query_embedding = embed_model.encode([query]).astype("float32")
-        curated_embeddings = embed_model.encode(curated_questions, show_progress_bar=False)
-        curated_embeddings = np.array(curated_embeddings).astype("float32")
-
-        # Build a temporary FAISS index for the curated questions
-        dimension = curated_embeddings.shape[1]
-        curated_index = faiss.IndexFlatL2(dimension)
-        curated_index.add(curated_embeddings)
-
-        k = 1
-        distances, indices = curated_index.search(query_embedding, k)
-
-        if distances[0][0] < threshold:
-            idx = indices[0][0]
-            return curated_qa[idx]
-    except Exception as e:
-        st.error(f"Error retrieving curated Q/A pair: {e}")
-    return None
+def display_conversation():
+    """Displays conversation history in a structured chat format."""
+    for entry in st.session_state.conversation_history:
+        role, message = entry
+        with st.chat_message(role):
+            st.write(message)

-# ============================================================
-# Custom RAG Retriever: Uses your FAISS index & PDF passages
-# ============================================================
-class CustomRagRetriever(RagRetriever):
-    """
-    A custom retriever that uses your FAISS index and passages.
-    It encodes the query with the provided embedding model,
-    searches your FAISS index, and returns the top retrieved documents.
-    """
-    def __init__(self, config, faiss_index, passages, embed_model, tokenizer, n_docs=5):
-        self.faiss_index = faiss_index  # Your custom FAISS index of PDF embeddings
-        self.passages = passages        # List of PDF passage texts
-        self.embed_model = embed_model  # Embedding model used for encoding queries
-        self.n_docs = n_docs            # Number of top documents to retrieve
-        self.tokenizer = tokenizer      # Save tokenizer for internal use if needed
-        # Override init_retrieval to bypass loading default passages.
-        self.init_retrieval = lambda: None
-        # Call the parent constructor with the required arguments.
-        super().__init__(config, question_encoder_tokenizer=tokenizer, generator_tokenizer=tokenizer)
-
-    def retrieve(self, query, n_docs=None):
-        try:
-            if n_docs is None:
-                n_docs = self.n_docs
-            # Encode the query using the embedding model
-            query_embedding = self.embed_model.encode([query]).astype("float32")
-            distances, indices = self.faiss_index.search(query_embedding, n_docs)
-            retrieved_docs = [self.passages[i] for i in indices[0]]
-            return {
-                "doc_ids": indices,
-                "doc_scores": distances,
-                "retrieved_docs": retrieved_docs,
-            }
-        except Exception as e:
-            st.error(f"Error in custom retrieval: {e}")
-            return {"doc_ids": None, "doc_scores": None, "retrieved_docs": []}
+def add_to_conversation(role, message):
+    """Adds a message to conversation history."""
+    st.session_state.conversation_history.append((role, message))

-# ============================================================
-# Load RAG Model with Custom Retriever (cached for performance)
-# ============================================================
-@st.cache_resource
-def load_rag_model(_faiss_index, passages, _embed_model):
-    try:
-        tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
-        rag_model = RagSequenceForGeneration.from_pretrained("facebook/rag-token-nq")
-
-        custom_retriever = CustomRagRetriever(
-            config=rag_model.config,
-            faiss_index=_faiss_index,
-            passages=passages,
-            embed_model=_embed_model,
-            tokenizer=tokenizer,
-            n_docs=5
-        )
-        rag_model.set_retriever(custom_retriever)
-        return tokenizer, rag_model
-    except Exception as e:
-        st.error(f"Error loading RAG model with custom retriever: {e}")
-        return None, None
-
-def generate_rag_answer(query, tokenizer, rag_model):
-    try:
-        inputs = tokenizer(query, return_tensors="pt")
-        with torch.no_grad():
-            generated_ids = rag_model.generate(**inputs)
-        answer = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
-        return answer
-    except Exception as e:
-        st.error(f"Error generating answer with RAG model: {e}")
-        return ""
+# Initialize conversation history
+if "conversation_history" not in st.session_state:
+    st.session_state.conversation_history = []

 # ========================================
-# Main Streamlit App
+# Main Streamlit Chat UI
 # ========================================
-def main():
-    st.title("Takalama")
-
-    if "conversation_history" not in st.session_state:
-        st.session_state.conversation_history = ""
-
-    with st.spinner("Loading index and passages..."):
-        index, chunks = load_index_and_chunks()
-    if index is None or chunks is None:
-        return
-
-    with st.spinner("Loading embedding model..."):
-        embed_model = load_embedding_model()
-    if embed_model is None:
-        return
-
-    with st.spinner("Loading QA pipeline..."):
-        qa_pipeline = load_qa_pipeline()
-    if qa_pipeline is None:
-        return
-
-    with st.spinner("Loading curated Q/A pairs..."):
-        curated_qa_pairs = load_curated_qa_pairs()
+st.title("Takalama - AI Chat")
+
+# Load models & data
+index, chunks = load_index_and_chunks()
+embed_model = load_embedding_model()
+qa_pipeline = load_qa_pipeline()
+curated_qa_pairs = load_curated_qa_pairs()
+
+display_conversation()
+
+# User Input
+user_query = st.chat_input("Ask a question about the document...")
+if user_query:
+    add_to_conversation("user", user_query)

-    st.write("Enter your question about the PDF document:")
-    query = st.text_input("Question:")
+    # Check for curated Q/A pair
+    answer = None
+    for pair in curated_qa_pairs:
+        if user_query.lower() in pair["question"].lower():
+            answer = pair["answer"]
+            break

-    if query:
-        st.session_state.conversation_history += f"User: {query}\n"
-
-        with st.spinner("Retrieving relevant PDF context..."):
-            try:
-                query_embedding = embed_model.encode([query]).astype("float32")
-                k = 3  # Retrieve top 3 chunks
-                distances, indices = index.search(query_embedding, k)
-                pdf_context = ""
-                for idx in indices[0]:
-                    pdf_context += chunks[idx] + "\n"
-            except Exception as e:
-                st.error(f"Error retrieving PDF context: {e}")
-                return
-
-        base_context = st.session_state.conversation_history + "\n"
-
-        # --- Option 1: Use RAG Model with Custom Retriever ---
-        if st.button("Use RAG Model with Custom Retriever"):
-            with st.spinner("Generating answer using RAG model..."):
-                tokenizer_rag, rag_model = load_rag_model(index, chunks, embed_model)
-                if tokenizer_rag is None or rag_model is None:
-                    return
-                rag_answer = generate_rag_answer(query, tokenizer_rag, rag_model)
-                st.write("**RAG Model Answer:**")
-                st.write(rag_answer)
-                st.session_state.conversation_history += f"AI (RAG): {rag_answer}\n"
-            return
-
-        # --- Option 2: Use Standard QA Pipeline with Curated Q/A Pairs ---
-        with st.spinner("Checking for curated Q/A pair..."):
-            curated_pair = get_curated_pair(query, curated_qa_pairs, embed_model)
+    if not answer:
+        # Retrieve relevant context
+        query_embedding = embed_model.encode([user_query]).astype("float32")
+        distances, indices = index.search(query_embedding, 3)
+        pdf_context = "\n".join(chunks[idx] for idx in indices[0])

-        if curated_pair:
-            st.info("A curated Q/A pair was found and will be used for the answer by default.")
-            use_full_data = st.checkbox("High Reasoning", value=False)
-            if not use_full_data:
-                answer = curated_pair["answer"]
-                st.write(answer)
-                st.session_state.conversation_history += f"AI: {answer}\n"
-                return
-            else:
-                context_to_use = base_context + pdf_context
-        else:
-            context_to_use = base_context + pdf_context
-
-        with st.expander("Show Full PDF Context"):
-            st.write(pdf_context)
-
-        st.subheader("Answer:")
-        with st.spinner("Generating answer using standard QA pipeline..."):
-            try:
-                result = qa_pipeline(question=query, context=context_to_use)
-                answer = result["answer"]
-                st.write(answer)
-                st.session_state.conversation_history += f"AI: {answer}\n"
-            except Exception as e:
-                st.error(f"Error generating answer using QA pipeline: {e}")
-
-if __name__ == "__main__":
-    main()
+        # Generate an answer using the QA pipeline
+        response = qa_pipeline(question=user_query, context=pdf_context)
+        answer = response.get("answer", "I couldn't find an answer to that.")
+
+    add_to_conversation("assistant", answer)
+    st.rerun()