rithvik213 committed on
Commit
5e31ea3
·
1 Parent(s): c009c1e

updated RAG pipeline and streamlit file

Browse files
Files changed (2) hide show
  1. RAG.py +30 -202
  2. streamlit_app.py +37 -21
RAG.py CHANGED
@@ -62,44 +62,32 @@ def extract_text_from_json(json_data: Dict) -> str:
62
  return " ".join(text_parts) if text_parts else "No content available"
63
 
64
  def rephrase_and_expand_query(query: str, llm: Any) -> str:
65
-
66
- # Use LLM to rewrite and expand a query for better alignment with archive metadata.
67
  prompt_template = PromptTemplate.from_template(
68
  """
69
  You are a professional librarian skilled at historical research.
70
- Your task is to improve and expand the following search query to better match metadata in a historical archive.
71
-
72
- - First, rewrite the query to improve clarity and fit how librarians would search.
73
- - Second, expand the query by adding related terms (synonyms, related concepts, historical terminology, etc.).
74
-
75
- Return your output strictly in this format (no extra explanation):
76
  <IMPROVED_QUERY>your improved query here</IMPROVED_QUERY>
77
  <EXPANDED_QUERY>your expanded query here</EXPANDED_QUERY>
78
 
79
  Original Query: {query}
80
  """
81
  )
82
-
83
  prompt = prompt_template.invoke({"query": query})
84
  response = llm.invoke(prompt)
85
 
86
- # Extract just the improved and expanded queries
87
  improved_match = re.search(r"<IMPROVED_QUERY>(.*?)</IMPROVED_QUERY>", response.content, re.DOTALL)
88
  expanded_match = re.search(r"<EXPANDED_QUERY>(.*?)</EXPANDED_QUERY>", response.content, re.DOTALL)
89
 
90
  improved_query = improved_match.group(1).strip() if improved_match else query
91
  expanded_query = expanded_match.group(1).strip() if expanded_match else ""
92
 
93
- final_query = f"{improved_query} {expanded_query}".strip()
94
-
95
- logging.info(f"Original Query: {query}")
96
- logging.info(f"Improved Query: {improved_query}")
97
- logging.info(f"Expanded Query: {expanded_query}")
98
- logging.info(f"Final Query for Retrieval: {final_query}")
99
-
100
- return final_query
101
-
102
 
 
 
 
103
 
104
  weights = {
105
  "title_info_primary_tsi": 1.5, # Titles should be prioritized
@@ -164,132 +152,13 @@ def get_metadata_from_api(document_ids: List[str]) -> Dict[str, Dict]:
164
  metadata_dict[doc_id] = extract_text_from_json(json_data)
165
  return metadata_dict
166
 
167
-
168
-
169
- """
170
- def rerank(documents: List[Document], query: str) -> List[Document]:
171
- \"\"\"Ingest more metadata. Rerank documents using BM25\"\"\"
172
- start = time.time()
173
- if not documents:
174
- return []
175
-
176
- full_docs = []
177
- seen_sources = set()
178
- meta_start = time.time()
179
- for doc in documents:
180
- source = doc.metadata.get('source')
181
- if not source or source in seen_sources:
182
- continue # Skip duplicate sources
183
- seen_sources.add(source)
184
-
185
- url = f"https://www.digitalcommonwealth.org/search/{source}"
186
- json_data = safe_get_json(f"{url}.json")
187
-
188
- if json_data:
189
- text_content = extract_text_from_json(json_data)
190
- if text_content: # Only add documents with actual content
191
- full_docs.append(Document(page_content=text_content, metadata={"source": source, "field": doc.metadata.get("field", ""), "URL": url}))
192
-
193
- logging.info(f"Took {time.time()-meta_start} seconds to retrieve all metadata")
194
- if not full_docs:
195
- return []
196
-
197
- # Create BM25 retriever with the processed documents
198
- bm25 = BM25Retriever.from_documents(full_docs, k=min(10, len(full_docs)))
199
- bm25_ranked_docs = bm25.invoke(query)
200
-
201
- ranked_docs = []
202
- for doc in bm25_ranked_docs:
203
- bm25_score = 1.0
204
-
205
- # Compute metadata multiplier
206
- metadata_multiplier = 1.0
207
- for field, weight in weights.items():
208
- if field in doc.metadata and doc.metadata[field]:
209
- metadata_multiplier += weight
210
-
211
- # Compute final score: BM25 weight * Metadata multiplier
212
- final_score = bm25_score * metadata_multiplier
213
- ranked_docs.append((doc, final_score))
214
-
215
- # Sort by final score
216
- ranked_docs.sort(key=lambda x: x[1], reverse=True)
217
-
218
- logging.info(f"Finished reranking: {time.time()-start}")
219
- return [doc for doc, _ in ranked_docs]
220
- """
221
-
222
- '''
223
  def rerank(documents: List[Document], query: str) -> List[Document]:
224
- """Retrieve metadata from the database and rerank using BM25"""
225
- start = time.time()
226
  if not documents:
227
  return []
228
 
229
- document_ids = [doc.metadata.get('source') for doc in documents if doc.metadata.get('source')]
230
-
231
- # Fetch metadata from PostgreSQL
232
- metadata_dict = get_metadata_from_db(document_ids)
233
-
234
- full_docs = []
235
- for doc in documents:
236
- doc_id = doc.metadata.get('source')
237
- metadata = metadata_dict.get(doc_id, {})
238
-
239
- if metadata:
240
- text_content = " ".join([
241
- metadata.get("title", ""),
242
- metadata.get("abstract", ""),
243
- " ".join(metadata.get("subjects", [])),
244
- metadata.get("institution", "")
245
- ]).strip()
246
-
247
-
248
- if text_content:
249
- full_docs.append(Document(page_content=text_content, metadata={
250
- "source": doc_id,
251
- "URL": metadata.get("metadata_url", ""),
252
- "image_url": metadata.get("image_url", "")
253
- }))
254
-
255
- logging.info(f"Took {time.time()-start} seconds to retrieve all metadata from PostgreSQL")
256
-
257
- if not full_docs:
258
- return []
259
-
260
- # Rerank using BM25
261
- bm25 = BM25Retriever.from_documents(full_docs, k=min(10, len(full_docs)))
262
- bm25_ranked_docs = bm25.invoke(query)
263
-
264
- ranked_docs = []
265
- for doc in bm25_ranked_docs:
266
- bm25_score = 1.0
267
-
268
- # Compute metadata multiplier
269
- metadata_multiplier = 1.0
270
- for field, weight in weights.items():
271
- if field in doc.metadata and doc.metadata[field]:
272
- metadata_multiplier += weight
273
-
274
- # Compute final score: BM25 weight * Metadata multiplier
275
- final_score = bm25_score * metadata_multiplier
276
- ranked_docs.append((doc, final_score))
277
-
278
- # Sort by final score
279
- ranked_docs.sort(key=lambda x: x[1], reverse=True)
280
-
281
- logging.info(f"Finished reranking: {time.time()-start}")
282
- return [doc for doc, _ in ranked_docs]
283
- '''
284
-
285
- def rerank(documents: List[Document], query: str) -> List[Document]:
286
- """Rerank using BM25 and enhance scores using document metadata."""
287
- start = time.time()
288
-
289
- if not documents:
290
- return []
291
 
292
- # Group document chunks by source_id
293
  grouped = defaultdict(list)
294
  for doc in documents:
295
  source_id = doc.metadata.get("source")
@@ -298,49 +167,39 @@ def rerank(documents: List[Document], query: str) -> List[Document]:
298
 
299
  full_docs = []
300
  for source_id, chunks in grouped.items():
301
- combined_text = " ".join([chunk.page_content for chunk in chunks if chunk.page_content])
302
- representative_metadata = chunks[0].metadata or {}
303
-
304
- #logging.debug(f"Metadata for doc {source_id}: {representative_metadata}")
305
-
306
- if combined_text.strip():
307
- full_docs.append(Document(
308
- page_content=combined_text.strip(),
309
- metadata={
310
- "source": source_id,
311
- "URL": representative_metadata.get("metadata_url", ""),
312
- "image_url": representative_metadata.get("image_url", ""),
313
- **representative_metadata # preserve all original fields
314
- }
315
- ))
316
-
317
- logging.info(f"Built {len(full_docs)} documents for reranking in {time.time() - start:.2f} seconds.")
318
 
319
  if not full_docs:
320
  return []
321
 
322
- # BM25 reranking
323
- bm25 = BM25Retriever.from_documents(full_docs, k=min(10, len(full_docs)))
324
  bm25_ranked_docs = bm25.invoke(query)
325
 
326
- # Score enhancement using metadata weights
327
  ranked_docs = []
328
  for doc in bm25_ranked_docs:
329
- bm25_score = 1.0 # BM25 returns sorted, so base score is 1
330
  metadata_multiplier = 1.0
 
331
  for field, weight in weights.items():
332
  if field in doc.metadata and doc.metadata[field]:
333
  metadata_multiplier += weight
 
 
 
 
 
 
 
334
  final_score = bm25_score * metadata_multiplier
335
  ranked_docs.append((doc, final_score))
336
 
337
- # Sort by enhanced score
338
  ranked_docs.sort(key=lambda x: x[1], reverse=True)
339
- logging.info(f"Finished reranking in {time.time() - start:.2f} seconds")
340
-
341
- return [doc for doc, _ in ranked_docs]
342
-
343
-
344
 
345
  def parse_xml_and_query(query:str,xml_string:str) -> str:
346
  """parse xml and return rephrased query"""
@@ -376,43 +235,12 @@ def RAG(llm: Any, query: str,vectorstore:PineconeVectorStore, top: int = 10, k:
376
 
377
  # Query alignment is commented out, however I have decided to leave it in for potential future use.
378
 
379
- # Retrieve initial documents using rephrased query -- not working as intended currently, maybe would be better for data with more words.
380
- # query_template = PromptTemplate.from_template(
381
- # """
382
- # Your job is to think about a query and then generate a statement that only includes information from the query that would answer the query.
383
- # You will be provided with a query in <QUERY></QUERY> tags.
384
- # Then you will think about what kind of information the query is looking for between <REASONING></REASONING> tags.
385
- # Then, based on the reasoning, you will generate a sample response to the query that only includes information from the query between <STATEMENT></STATEMENT> tags.
386
- # Afterwards, you will determine and reason about whether or not the statement you generated only includes information from the original query and would answer the query between <DETERMINATION></DETERMINATION> tags.
387
- # Finally, you will return a YES, or NO response between <VALID></VALID> tags based on whether or not you determined the statement to be valid.
388
- # Let me provide you with an example:
389
-
390
- # <QUERY>I would really like to learn more about Bermudan geography<QUERY>
391
-
392
- # <REASONING>This query is interested in geography as it relates to Bermuda. Some things they might be interested in are Bermudan climate, towns, cities, and geography</REASONING>
393
-
394
- # <STATEMENT>Bermuda's Climate is [blank]. Some of Bermuda's cities and towns are [blank]. Other points of interest about Bermuda's geography are [blank].</STATEMENT>
395
-
396
- # <DETERMINATION>The query originally only mentions bermuda and geography. The answers do not provide any false information, instead replacing meaningful responses with a placeholder [blank]. If it had hallucinated, it would not be valid. Because the statements do not hallucinate anything, this is a valid statement.</DETERMINATION>
397
-
398
- # <VALID>YES</VALID>
399
-
400
- # Now it's your turn! Remember not to hallucinate:
401
-
402
- # <QUERY>{query}</QUERY>
403
- # """
404
- # )
405
- # query_prompt = query_template.invoke({"query":query})
406
- # query_response = llm.invoke(query_prompt)
407
- # new_query = parse_xml_and_query(query=query,xml_string=query_response.content)
408
-
409
- #logging.info(f"\n---\nQUERY: {query}")
410
-
411
- #new query rephrasing
412
- #query = rephrase_and_expand_query(query, llm)
413
- #logging.info(f"\n---\nRephrased QUERY: {query}")
414
 
415
  retrieved, _ = retrieve(query=query, vectorstore=vectorstore, k=k)
 
416
  if not retrieved:
417
  return "No documents found for your query.", []
418
 
 
62
  return " ".join(text_parts) if text_parts else "No content available"
63
 
64
  def rephrase_and_expand_query(query: str, llm: Any) -> str:
65
+ """Use LLM to rewrite and expand a query for better alignment with archive metadata."""
 
66
  prompt_template = PromptTemplate.from_template(
67
  """
68
  You are a professional librarian skilled at historical research.
69
+ Rewrite and expand the query to match metadata tags. Include related terms (synonyms, historical names, places, events).
70
+
 
 
 
 
71
  <IMPROVED_QUERY>your improved query here</IMPROVED_QUERY>
72
  <EXPANDED_QUERY>your expanded query here</EXPANDED_QUERY>
73
 
74
  Original Query: {query}
75
  """
76
  )
 
77
  prompt = prompt_template.invoke({"query": query})
78
  response = llm.invoke(prompt)
79
 
 
80
  improved_match = re.search(r"<IMPROVED_QUERY>(.*?)</IMPROVED_QUERY>", response.content, re.DOTALL)
81
  expanded_match = re.search(r"<EXPANDED_QUERY>(.*?)</EXPANDED_QUERY>", response.content, re.DOTALL)
82
 
83
  improved_query = improved_match.group(1).strip() if improved_match else query
84
  expanded_query = expanded_match.group(1).strip() if expanded_match else ""
85
 
86
+ return f"{improved_query} {expanded_query}".strip()
 
 
 
 
 
 
 
 
87
 
88
+ def extract_years_from_query(query: str) -> List[str]:
89
+ """Extract 4-digit years from query for boosting."""
90
+ return re.findall(r"\b(1[5-9]\d{2}|20\d{2}|21\d{2}|22\d{2}|23\d{2})\b", query)
91
 
92
  weights = {
93
  "title_info_primary_tsi": 1.5, # Titles should be prioritized
 
152
  metadata_dict[doc_id] = extract_text_from_json(json_data)
153
  return metadata_dict
154
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
  def rerank(documents: List[Document], query: str) -> List[Document]:
156
+ """Rerank documents using BM25 and metadata, boost if year matches."""
 
157
  if not documents:
158
  return []
159
 
160
+ query_years = extract_years_from_query(query)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
161
 
 
162
  grouped = defaultdict(list)
163
  for doc in documents:
164
  source_id = doc.metadata.get("source")
 
167
 
168
  full_docs = []
169
  for source_id, chunks in grouped.items():
170
+ combined_text = " ".join(chunk.page_content for chunk in chunks if chunk.page_content)
171
+ metadata = chunks[0].metadata if chunks else {}
172
+ full_docs.append(Document(
173
+ page_content=combined_text.strip(),
174
+ metadata={**metadata, "source": source_id}
175
+ ))
 
 
 
 
 
 
 
 
 
 
 
176
 
177
  if not full_docs:
178
  return []
179
 
180
+ bm25 = BM25Retriever.from_documents(full_docs, k=len(full_docs))
 
181
  bm25_ranked_docs = bm25.invoke(query)
182
 
 
183
  ranked_docs = []
184
  for doc in bm25_ranked_docs:
185
+ bm25_score = 1.0
186
  metadata_multiplier = 1.0
187
+
188
  for field, weight in weights.items():
189
  if field in doc.metadata and doc.metadata[field]:
190
  metadata_multiplier += weight
191
+
192
+ date_field = str(doc.metadata.get("date_tsim", ""))
193
+ for year in query_years:
194
+ if re.search(rf"\b{year}\b", date_field) or re.search(rf"{year[:-2]}\d{{2}}–{year[:-2]}\d{{2}}", date_field):
195
+ metadata_multiplier += 50
196
+ break
197
+
198
  final_score = bm25_score * metadata_multiplier
199
  ranked_docs.append((doc, final_score))
200
 
 
201
  ranked_docs.sort(key=lambda x: x[1], reverse=True)
202
+ return [doc for doc, _ in ranked_docs[:10]]
 
 
 
 
203
 
204
  def parse_xml_and_query(query:str,xml_string:str) -> str:
205
  """parse xml and return rephrased query"""
 
235
 
236
  # Query alignment is commented out, however I have decided to leave it in for potential future use.
237
 
238
+ # 🔄 Rephrase and expand the user query for better Pinecone matching
239
+ query = rephrase_and_expand_query(query, llm)
240
+ logging.info(f"Rephrased Query for Retrieval: {query}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
241
 
242
  retrieved, _ = retrieve(query=query, vectorstore=vectorstore, k=k)
243
+
244
  if not retrieved:
245
  return "No documents found for your query.", []
246
 
streamlit_app.py CHANGED
@@ -31,7 +31,7 @@ def initialize_models() -> Tuple[Optional[ChatOpenAI], HuggingFaceEmbeddings]:
31
  if "llm" not in st.session_state:
32
  # Initialize OpenAI model
33
  st.session_state.llm = ChatOpenAI(
34
- model="gpt-4o-mini", # Changed from gpt-4o-mini which appears to be a typo
35
  temperature=0,
36
  timeout=60, # Added reasonable timeout
37
  max_retries=2
@@ -81,7 +81,7 @@ def process_message(
81
  return f"Error processing message: {str(e)}", []
82
 
83
  def display_sources(sources: List) -> None:
84
- """Display sources with minimal output: content preview, source, URL, and image if available."""
85
  if not sources:
86
  st.info("No sources available for this response.")
87
  return
@@ -89,40 +89,56 @@ def display_sources(sources: List) -> None:
89
  st.subheader("Sources")
90
  for doc in sources:
91
  try:
92
- source = doc.metadata.get("source", "Unknown Source")
93
- title = doc.metadata.get("title_info_primary_tsi", "Unknown Title")
 
 
94
 
95
- with st.expander(f"{title}"):
 
 
 
 
96
  # Content preview
97
  if hasattr(doc, 'page_content'):
98
- st.markdown(f"**Content:** {doc.page_content[:100]} ...")
99
 
100
- # Extract URL
101
- doc_url = doc.metadata.get("URL", "").strip()
102
  if not doc_url and source:
103
  doc_url = f"https://www.digitalcommonwealth.org/search/{source}"
104
 
105
  st.markdown(f"**Source ID:** {source}")
 
106
  st.markdown(f"**URL:** {doc_url}")
107
 
108
- # Try to show an image
109
- scraper = DigitalCommonwealthScraper()
110
- images = scraper.extract_images(doc_url)
111
- images = images[:1]
112
-
113
- if images:
114
- output_dir = 'downloaded_images'
115
- if os.path.exists(output_dir):
116
- shutil.rmtree(output_dir)
117
- downloaded_files = scraper.download_images(images)
118
- st.image(downloaded_files, width=400, caption=[
119
- img.get('alt', f'Image') for img in images
120
- ])
 
 
 
 
 
 
 
 
121
  except Exception as e:
122
  logger.warning(f"[display_sources] Error displaying document: {e}")
123
  st.error("Error displaying one of the sources.")
124
 
125
 
 
126
  def main():
127
  st.title("Digital Commonwealth RAG 🤖")
128
 
 
31
  if "llm" not in st.session_state:
32
  # Initialize OpenAI model
33
  st.session_state.llm = ChatOpenAI(
34
+ model="gpt-3.5-turbo",
35
  temperature=0,
36
  timeout=60, # Added reasonable timeout
37
  max_retries=2
 
81
  return f"Error processing message: {str(e)}", []
82
 
83
  def display_sources(sources: List) -> None:
84
+ """Display sources with minimal output: content preview, source, URL, and image/audio if available."""
85
  if not sources:
86
  st.info("No sources available for this response.")
87
  return
 
89
  st.subheader("Sources")
90
  for doc in sources:
91
  try:
92
+ metadata = doc.metadata
93
+ source = metadata.get("source", "Unknown Source")
94
+ title = metadata.get("title_info_primary_tsi", "Unknown Title")
95
+ format_type = metadata.get("format", "").lower()
96
 
97
+ is_audio = "audio" in format_type
98
+
99
+ expander_title = f"🔊 {title}" if is_audio else title
100
+
101
+ with st.expander(expander_title):
102
  # Content preview
103
  if hasattr(doc, 'page_content'):
104
+ st.markdown(f"**Content:** {doc.page_content[:300]} ...")
105
 
106
+ # URL building
107
+ doc_url = metadata.get("URL", "").strip()
108
  if not doc_url and source:
109
  doc_url = f"https://www.digitalcommonwealth.org/search/{source}"
110
 
111
  st.markdown(f"**Source ID:** {source}")
112
+ st.markdown(f"**Format:** {format_type if format_type else 'Not specified'}")
113
  st.markdown(f"**URL:** {doc_url}")
114
 
115
+ # 🔊 Try to show audio if it's an audio entry and there's a media file
116
+ if is_audio:
117
+ # Try to find a playable media file β€” if metadata has audio URLs
118
+ # For now, just embed a dummy player or placeholder
119
+ st.info("This is an audio entry.")
120
+ # Optionally:
121
+ # st.audio("https://example.com/audio-file.mp3") # replace with real audio URL
122
+ else:
123
+ # πŸ–ΌοΈ Show image if it's not audio
124
+ scraper = DigitalCommonwealthScraper()
125
+ images = scraper.extract_images(doc_url)
126
+ images = images[:1]
127
+
128
+ if images:
129
+ output_dir = 'downloaded_images'
130
+ if os.path.exists(output_dir):
131
+ shutil.rmtree(output_dir)
132
+ downloaded_files = scraper.download_images(images)
133
+ st.image(downloaded_files, width=400, caption=[
134
+ img.get('alt', f'Image') for img in images
135
+ ])
136
  except Exception as e:
137
  logger.warning(f"[display_sources] Error displaying document: {e}")
138
  st.error("Error displaying one of the sources.")
139
 
140
 
141
+
142
  def main():
143
  st.title("Digital Commonwealth RAG 🤖")
144