Tamil Eniyan committed on
Commit 411f496 · 1 Parent(s): d8ffd44

Add application file

Files changed (1)
  1. app.py +139 -93
app.py CHANGED
@@ -24,53 +24,68 @@ QA_MODEL_NAME = "deepset/roberta-large-squad2" # For the standard QA pipeline

 @st.cache_resource
 def load_index_and_chunks():
-    index = faiss.read_index(INDEX_FILE)
-    with open(CHUNKS_FILE, "rb") as f:
-        chunks = pickle.load(f)
-    return index, chunks
+    try:
+        index = faiss.read_index(INDEX_FILE)
+        with open(CHUNKS_FILE, "rb") as f:
+            chunks = pickle.load(f)
+        return index, chunks
+    except Exception as e:
+        st.error(f"Error loading FAISS index and chunks: {e}")
+        return None, None

 @st.cache_resource
 def load_embedding_model():
-    return SentenceTransformer(EMBEDDING_MODEL_NAME)
+    try:
+        model = SentenceTransformer(EMBEDDING_MODEL_NAME)
+        return model
+    except Exception as e:
+        st.error(f"Error loading embedding model: {e}")
+        return None

 @st.cache_resource
 def load_qa_pipeline():
-    return pipeline("question-answering", model=QA_MODEL_NAME, tokenizer=QA_MODEL_NAME)
+    try:
+        qa_pipe = pipeline("question-answering", model=QA_MODEL_NAME, tokenizer=QA_MODEL_NAME)
+        return qa_pipe
+    except Exception as e:
+        st.error(f"Error loading QA pipeline: {e}")
+        return None

 @st.cache_resource
 def load_curated_qa_pairs(json_file=CURATED_QA_FILE):
-    with open(json_file, "r", encoding="utf-8") as f:
-        curated_qa_pairs = json.load(f)
-    return curated_qa_pairs
+    try:
+        with open(json_file, "r", encoding="utf-8") as f:
+            curated_qa_pairs = json.load(f)
+        return curated_qa_pairs
+    except Exception as e:
+        st.error(f"Error loading curated Q/A pairs from JSON: {e}")
+        return []

 # ========================================
 # Standard: Retrieve Curated Q/A Pair Function
 # ========================================

 def get_curated_pair(query, curated_qa, embed_model, threshold=1.0):
-    """
-    Retrieve the most relevant curated Q/A pair based on the user's query.
-    Returns the QA dictionary if the similarity (using L2 distance) is below the threshold,
-    otherwise returns None.
-    """
-    curated_questions = [qa["question"] for qa in curated_qa]
-    query_embedding = embed_model.encode([query]).astype('float32')
-    curated_embeddings = embed_model.encode(curated_questions, show_progress_bar=False)
-    curated_embeddings = np.array(curated_embeddings).astype('float32')
-
-    # Build a temporary FAISS index for the curated questions
-    dimension = curated_embeddings.shape[1]
-    curated_index = faiss.IndexFlatL2(dimension)
-    curated_index.add(curated_embeddings)
-
-    k = 1
-    distances, indices = curated_index.search(query_embedding, k)
-
-    if distances[0][0] < threshold:
-        idx = indices[0][0]
-        return curated_qa[idx]
-    else:
-        return None
+    try:
+        curated_questions = [qa["question"] for qa in curated_qa]
+        query_embedding = embed_model.encode([query]).astype('float32')
+        curated_embeddings = embed_model.encode(curated_questions, show_progress_bar=False)
+        curated_embeddings = np.array(curated_embeddings).astype('float32')
+
+        # Build a temporary FAISS index for the curated questions
+        dimension = curated_embeddings.shape[1]
+        curated_index = faiss.IndexFlatL2(dimension)
+        curated_index.add(curated_embeddings)
+
+        k = 1
+        distances, indices = curated_index.search(query_embedding, k)
+
+        if distances[0][0] < threshold:
+            idx = indices[0][0]
+            return curated_qa[idx]
+    except Exception as e:
+        st.error(f"Error retrieving curated Q/A pair: {e}")
+        return None

 # ============================================================
 # Custom RAG Retriever: Uses your FAISS index & PDF passages
@@ -92,48 +107,56 @@ class CustomRagRetriever(RagRetriever):
         super().__init__(dummy_dataset, tokenizer=tokenizer, index_name="custom")

     def retrieve(self, query, n_docs=None):
-        if n_docs is None:
-            n_docs = self.n_docs
-        # Encode the query using your embedding model
-        query_embedding = self.embed_model.encode([query]).astype('float32')
-        distances, indices = self.faiss_index.search(query_embedding, n_docs)
-        # Retrieve the passages using the indices
-        retrieved_docs = [self.passages[i] for i in indices[0]]
-        return {
-            "doc_ids": indices,
-            "doc_scores": distances,
-            "retrieved_docs": retrieved_docs,
-        }
+        try:
+            if n_docs is None:
+                n_docs = self.n_docs
+            # Encode the query using your embedding model
+            query_embedding = self.embed_model.encode([query]).astype('float32')
+            distances, indices = self.faiss_index.search(query_embedding, n_docs)
+            # Retrieve the passages using the indices
+            retrieved_docs = [self.passages[i] for i in indices[0]]
+            return {
+                "doc_ids": indices,
+                "doc_scores": distances,
+                "retrieved_docs": retrieved_docs,
+            }
+        except Exception as e:
+            st.error(f"Error in custom retrieval: {e}")
+            return {"doc_ids": None, "doc_scores": None, "retrieved_docs": []}

 # ============================================================
 # Load RAG Model with Custom Retriever (cached for performance)
 # ============================================================

 @st.cache_resource
-def load_rag_model(faiss_index, passages, embed_model):
-    # Load the RAG tokenizer and base model
-    tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
-    rag_model = RagSequenceForGeneration.from_pretrained("facebook/rag-token-nq")
-
-    # Create the custom retriever using your FAISS index, passages, and embedding model
-    custom_retriever = CustomRagRetriever(
-        faiss_index=faiss_index,
-        passages=passages,
-        embed_model=embed_model,
-        tokenizer=tokenizer,
-        n_docs=5
-    )
-
-    # Set the custom retriever in the RAG model
-    rag_model.set_retriever(custom_retriever)
-    return tokenizer, rag_model
+def load_rag_model(_faiss_index, passages, embed_model):
+    try:
+        tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
+        rag_model = RagSequenceForGeneration.from_pretrained("facebook/rag-token-nq")
+
+        custom_retriever = CustomRagRetriever(
+            faiss_index=_faiss_index,
+            passages=passages,
+            embed_model=embed_model,
+            tokenizer=tokenizer,
+            n_docs=5
+        )
+        rag_model.set_retriever(custom_retriever)
+        return tokenizer, rag_model
+    except Exception as e:
+        st.error(f"Error loading RAG model with custom retriever: {e}")
+        return None, None

 def generate_rag_answer(query, tokenizer, rag_model):
-    inputs = tokenizer(query, return_tensors="pt")
-    with torch.no_grad():
-        generated_ids = rag_model.generate(**inputs)
-    answer = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
-    return answer
+    try:
+        inputs = tokenizer(query, return_tensors="pt")
+        with torch.no_grad():
+            generated_ids = rag_model.generate(**inputs)
+        answer = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
+        return answer
+    except Exception as e:
+        st.error(f"Error generating answer with RAG model: {e}")
+        return ""

 # ========================================
 # Main Streamlit App
@@ -146,11 +169,24 @@ def main():
     if 'conversation_history' not in st.session_state:
         st.session_state.conversation_history = ""

-    # Load necessary data and models
-    index, chunks = load_index_and_chunks()
-    embed_model = load_embedding_model()
-    qa_pipeline = load_qa_pipeline()
-    curated_qa_pairs = load_curated_qa_pairs()
+    # Load necessary data and models with spinners for responsiveness
+    with st.spinner("Loading index and passages..."):
+        index, chunks = load_index_and_chunks()
+    if index is None or chunks is None:
+        return
+
+    with st.spinner("Loading embedding model..."):
+        embed_model = load_embedding_model()
+    if embed_model is None:
+        return
+
+    with st.spinner("Loading QA pipeline..."):
+        qa_pipeline = load_qa_pipeline()
+    if qa_pipeline is None:
+        return
+
+    with st.spinner("Loading curated Q/A pairs..."):
+        curated_qa_pairs = load_curated_qa_pairs()

     st.write("Enter your question about the PDF document:")
     query = st.text_input("Question:")
@@ -160,29 +196,38 @@
         st.session_state.conversation_history += f"User: {query}\n"

         # Retrieve relevant PDF context using the FAISS index
-        query_embedding = embed_model.encode([query]).astype('float32')
-        k = 3 # Number of top chunks to retrieve
-        distances, indices = index.search(query_embedding, k)
-        pdf_context = ""
-        for idx in indices[0]:
-            pdf_context += chunks[idx] + "\n"
+        with st.spinner("Retrieving relevant PDF context..."):
+            try:
+                query_embedding = embed_model.encode([query]).astype('float32')
+                k = 3 # Number of top chunks to retrieve
+                distances, indices = index.search(query_embedding, k)
+                pdf_context = ""
+                for idx in indices[0]:
+                    pdf_context += chunks[idx] + "\n"
+            except Exception as e:
+                st.error(f"Error retrieving PDF context: {e}")
+                return

         base_context = st.session_state.conversation_history + "\n"

         # --- Option 1: Use RAG Model with Custom Retriever ---
         if st.button("Use RAG Model with Custom Retriever"):
-            tokenizer_rag, rag_model = load_rag_model(index, chunks, embed_model)
-            rag_answer = generate_rag_answer(query, tokenizer_rag, rag_model)
-            st.write("**RAG Model Answer:**")
-            st.write(rag_answer)
-            st.session_state.conversation_history += f"AI (RAG): {rag_answer}\n"
-            return # Exit after using the RAG answer
+            with st.spinner("Generating answer using RAG model..."):
+                tokenizer_rag, rag_model = load_rag_model(index, chunks, embed_model)
+                if tokenizer_rag is None or rag_model is None:
+                    return
+                rag_answer = generate_rag_answer(query, tokenizer_rag, rag_model)
+                st.write("**RAG Model Answer:**")
+                st.write(rag_answer)
+                st.session_state.conversation_history += f"AI (RAG): {rag_answer}\n"
+            return # Exit after using the RAG answer

         # --- Option 2: Use Standard QA Pipeline with Curated Q/A Pairs ---
-        # Check for a curated Q/A pair
-        curated_pair = get_curated_pair(query, curated_qa_pairs, embed_model)
+        with st.spinner("Checking for curated Q/A pair..."):
+            curated_pair = get_curated_pair(query, curated_qa_pairs, embed_model)
+
         if curated_pair:
-            st.write("A curated Q/A pair was found and will be used for the answer by default.")
+            st.info("A curated Q/A pair was found and will be used for the answer by default.")
             # Option to override with full PDF context ("High Reasoning")
             use_full_data = st.checkbox("High Reasoning", value=False)
             if not use_full_data:
@@ -200,13 +245,14 @@
         st.write(pdf_context)

         st.subheader("Answer:")
-        try:
-            result = qa_pipeline(question=query, context=context_to_use)
-            answer = result["answer"]
-            st.write(answer)
-            st.session_state.conversation_history += f"AI: {answer}\n"
-        except Exception as e:
-            st.error(f"Error generating answer: {e}")
+        with st.spinner("Generating answer using standard QA pipeline..."):
+            try:
+                result = qa_pipeline(question=query, context=context_to_use)
+                answer = result["answer"]
+                st.write(answer)
+                st.session_state.conversation_history += f"AI: {answer}\n"
+            except Exception as e:
+                st.error(f"Error generating answer using QA pipeline: {e}")

 if __name__ == "__main__":
     main()