Tamil Eniyan committed
Commit d8ffd44 · 1 Parent(s): 6170969

Add application file

Files changed (1)
  1. app.py +98 -13
app.py CHANGED
@@ -4,16 +4,23 @@ import numpy as np
 import pickle
 import json
 from sentence_transformers import SentenceTransformer
-from transformers import pipeline
+from transformers import pipeline, RagTokenizer, RagRetriever, RagSequenceForGeneration
+import torch
+
+# ========================
+# File Names & Model Names
+# ========================
 
-# File names for saved PDF-based data and curated Q/A pairs
 INDEX_FILE = "faiss_index.index"
 CHUNKS_FILE = "chunks.pkl"
 CURATED_QA_FILE = "curated_qa_pairs.json"
 
-# Models
 EMBEDDING_MODEL_NAME = "all-MiniLM-L6-v2"
-QA_MODEL_NAME = "deepset/roberta-large-squad2"  # More powerful QA model
+QA_MODEL_NAME = "deepset/roberta-large-squad2"  # For the standard QA pipeline
+
+# ========================
+# Loading Functions (cached)
+# ========================
 
 @st.cache_resource
 def load_index_and_chunks():
@@ -36,6 +43,10 @@ def load_curated_qa_pairs(json_file=CURATED_QA_FILE):
         curated_qa_pairs = json.load(f)
     return curated_qa_pairs
 
+# ========================================
+# Standard: Retrieve Curated Q/A Pair Function
+# ========================================
+
 def get_curated_pair(query, curated_qa, embed_model, threshold=1.0):
     """
     Retrieve the most relevant curated Q/A pair based on the user's query.
@@ -61,14 +72,81 @@ def get_curated_pair(query, curated_qa, embed_model, threshold=1.0):
     else:
         return None
 
+# ============================================================
+# Custom RAG Retriever: Uses your FAISS index & PDF passages
+# ============================================================
+
+class CustomRagRetriever(RagRetriever):
+    """
+    A custom retriever that uses your FAISS index and passages.
+    It encodes the query with the provided embedding model,
+    searches your FAISS index, and returns the top retrieved documents.
+    """
+    def __init__(self, faiss_index, passages, embed_model, tokenizer, n_docs=5):
+        self.faiss_index = faiss_index  # FAISS index of PDF embeddings
+        self.passages = passages        # List of passage texts from your PDF
+        self.embed_model = embed_model  # Embedding model to encode queries
+        self.n_docs = n_docs            # Number of top documents to retrieve
+        # Provide a dummy dataset to satisfy the base class.
+        dummy_dataset = {"title": ["dummy"], "text": ["dummy"]}
+        super().__init__(dummy_dataset, tokenizer=tokenizer, index_name="custom")
+
+    def retrieve(self, query, n_docs=None):
+        if n_docs is None:
+            n_docs = self.n_docs
+        # Encode the query using your embedding model
+        query_embedding = self.embed_model.encode([query]).astype('float32')
+        distances, indices = self.faiss_index.search(query_embedding, n_docs)
+        # Retrieve the passages using the indices
+        retrieved_docs = [self.passages[i] for i in indices[0]]
+        return {
+            "doc_ids": indices,
+            "doc_scores": distances,
+            "retrieved_docs": retrieved_docs,
+        }
+
+# ============================================================
+# Load RAG Model with Custom Retriever (cached for performance)
+# ============================================================
+
+@st.cache_resource
+def load_rag_model(faiss_index, passages, embed_model):
+    # Load the RAG tokenizer and base model
+    tokenizer = RagTokenizer.from_pretrained("facebook/rag-token-nq")
+    rag_model = RagSequenceForGeneration.from_pretrained("facebook/rag-token-nq")
+
+    # Create the custom retriever using your FAISS index, passages, and embedding model
+    custom_retriever = CustomRagRetriever(
+        faiss_index=faiss_index,
+        passages=passages,
+        embed_model=embed_model,
+        tokenizer=tokenizer,
+        n_docs=5
+    )
+
+    # Set the custom retriever in the RAG model
+    rag_model.set_retriever(custom_retriever)
+    return tokenizer, rag_model
+
+def generate_rag_answer(query, tokenizer, rag_model):
+    inputs = tokenizer(query, return_tensors="pt")
+    with torch.no_grad():
+        generated_ids = rag_model.generate(**inputs)
+    answer = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
+    return answer
+
+# ========================================
+# Main Streamlit App
+# ========================================
+
 def main():
     st.title("Takalama")
 
-    # Initialize conversation history if not already set
+    # Initialize conversation history if not present
     if 'conversation_history' not in st.session_state:
         st.session_state.conversation_history = ""
 
-    # Load PDF index, chunks, models, and curated Q/A pairs
+    # Load necessary data and models
     index, chunks = load_index_and_chunks()
     embed_model = load_embedding_model()
     qa_pipeline = load_qa_pipeline()
@@ -78,39 +156,46 @@ def main():
     query = st.text_input("Question:")
 
     if query:
-        # Append the current question to conversation history
+        # Append question to conversation history
         st.session_state.conversation_history += f"User: {query}\n"
 
         # Retrieve relevant PDF context using the FAISS index
         query_embedding = embed_model.encode([query]).astype('float32')
         k = 3  # Number of top chunks to retrieve
         distances, indices = index.search(query_embedding, k)
-
         pdf_context = ""
         for idx in indices[0]:
             pdf_context += chunks[idx] + "\n"
 
         base_context = st.session_state.conversation_history + "\n"
 
+        # --- Option 1: Use RAG Model with Custom Retriever ---
+        if st.button("Use RAG Model with Custom Retriever"):
+            tokenizer_rag, rag_model = load_rag_model(index, chunks, embed_model)
+            rag_answer = generate_rag_answer(query, tokenizer_rag, rag_model)
+            st.write("**RAG Model Answer:**")
+            st.write(rag_answer)
+            st.session_state.conversation_history += f"AI (RAG): {rag_answer}\n"
+            return  # Exit after using the RAG answer
+
+        # --- Option 2: Use Standard QA Pipeline with Curated Q/A Pairs ---
         # Check for a curated Q/A pair
        curated_pair = get_curated_pair(query, curated_qa_pairs, embed_model)
-
         if curated_pair:
             st.write("A curated Q/A pair was found and will be used for the answer by default.")
-            # Option to override with full PDF context, now labeled as "High Reasoning"
+            # Option to override with full PDF context ("High Reasoning")
             use_full_data = st.checkbox("High Reasoning", value=False)
             if not use_full_data:
-                # Directly display the curated answer without running the QA pipeline
                 answer = curated_pair["answer"]
                 st.write(answer)
                 st.session_state.conversation_history += f"AI: {answer}\n"
-                return  # Exit the function after displaying the curated answer
+                return  # Exit after displaying the curated answer
             else:
                 context_to_use = base_context + pdf_context
         else:
             context_to_use = base_context + pdf_context
 
-        # Provide an expander to show the full PDF context if desired
+        # Optionally, show the full PDF context in an expander
         with st.expander("Show Full PDF Context"):
             st.write(pdf_context)
 
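Note that the diff only touches the signature, docstring, and trailing `return None` of `get_curated_pair`; the matching logic itself sits outside the changed hunks. For reference, a minimal sketch consistent with the `threshold=1.0` default and the SentenceTransformer `embed_model` argument could look like the following (the `"question"`/`"answer"` field names and the L2-distance threshold are assumptions, not taken from this commit):

```python
# Hypothetical body for get_curated_pair -- an assumption, since the commit
# only shows the signature, docstring, and the final "return None".
import numpy as np

def get_curated_pair(query, curated_qa, embed_model, threshold=1.0):
    # Encode the user query with the same embedding model used for the FAISS index
    query_vec = embed_model.encode([query]).astype("float32")[0]
    best_pair, best_dist = None, float("inf")
    for pair in curated_qa:  # assumes [{"question": ..., "answer": ...}, ...]
        q_vec = embed_model.encode([pair["question"]]).astype("float32")[0]
        dist = float(np.linalg.norm(query_vec - q_vec))  # L2 distance, matching the FAISS metric
        if dist < best_dist:
            best_pair, best_dist = pair, dist
    # Only accept the closest curated question if it falls under the distance threshold
    if best_pair is not None and best_dist < threshold:
        return best_pair
    else:
        return None
```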
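The app also expects `faiss_index.index`, `chunks.pkl`, and `curated_qa_pairs.json` to already exist on disk; the script that builds them is not part of this commit. A minimal sketch of how they could be produced, assuming plain-string chunks and a flat L2 FAISS index over `all-MiniLM-L6-v2` embeddings (the file name `build_artifacts.py` and the sample data are hypothetical):

```python
# build_artifacts.py -- hypothetical helper, not included in this commit.
import json
import pickle

import faiss
import numpy as np
from sentence_transformers import SentenceTransformer

# Placeholder passages and curated pairs; in practice these come from the PDF.
chunks = ["First PDF passage...", "Second PDF passage..."]
curated_qa = [{"question": "What is Takalama?", "answer": "A PDF question-answering demo."}]

# Embed the chunks with the same model name app.py uses (EMBEDDING_MODEL_NAME)
model = SentenceTransformer("all-MiniLM-L6-v2")
embeddings = np.asarray(model.encode(chunks), dtype="float32")

# Flat L2 index, matching the index.search() calls in app.py
index = faiss.IndexFlatL2(embeddings.shape[1])
index.add(embeddings)

# Write the three artifacts app.py loads at startup
faiss.write_index(index, "faiss_index.index")
with open("chunks.pkl", "wb") as f:
    pickle.dump(chunks, f)
with open("curated_qa_pairs.json", "w") as f:
    json.dump(curated_qa, f, indent=2)
```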