rithvik213 commited on
Commit
d5c23d7
·
1 Parent(s): d1b836e

added files to run app

Browse files
Files changed (2) hide show
  1. RAG.py +469 -0
  2. streamlit_app.py +214 -0
RAG.py ADDED
@@ -0,0 +1,469 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import getpass
2
+ import os
3
+ import time
4
+ from pinecone import Pinecone, ServerlessSpec
5
+ from langchain_pinecone import PineconeVectorStore
6
+ from langchain_huggingface import HuggingFaceEmbeddings
7
+ from dotenv import load_dotenv
8
+ from langchain_core.prompts import PromptTemplate
9
+ from langchain_openai import ChatOpenAI
10
+ import re
11
+ from langchain_core.documents import Document
12
+ from langchain_community.retrievers import BM25Retriever
13
+ import requests
14
+ import psycopg2
15
+ from collections import defaultdict
16
+ from typing import Dict, Any, Optional, List, Tuple
17
+ import json
18
+ import logging
19
+
20
+ def retrieve(query: str,vectorstore:PineconeVectorStore, k: int = 1000) -> Tuple[List[Document], List[float]]:
21
+ start = time.time()
22
+ results = vectorstore.similarity_search_with_score(
23
+ query,
24
+ k=k,
25
+ )
26
+ documents = []
27
+ scores = []
28
+ for res, score in results:
29
+ # check to make sure response isnt too long for context window of 4o-mini
30
+ if len(res.page_content) > 4000:
31
+ res.page_content = res.page_content[:4000]
32
+ documents.append(res)
33
+ scores.append(score)
34
+ logging.info(f"Finished Retrieval: {time.time() - start}")
35
+ return documents, scores
36
+
37
+ def safe_get_json(url: str) -> Optional[Dict]:
38
+ """Safely fetch and parse JSON from a URL."""
39
+ print("Fetching JSON")
40
+ try:
41
+ response = requests.get(url, timeout=10)
42
+ response.raise_for_status()
43
+ return response.json()
44
+ except Exception as e:
45
+ logging.error(f"Error fetching from {url}: {str(e)}")
46
+ return None
47
+
48
+ def extract_text_from_json(json_data: Dict) -> str:
49
+ """Extract text content from JSON response."""
50
+ if not json_data:
51
+ return ""
52
+
53
+ text_parts = []
54
+
55
+ # Handle direct text fields
56
+ text_fields = ["title_info_primary_tsi","abstract_tsi","subject_geographic_sim","genre_basic_ssim","genre_specific_ssim","date_tsim"]
57
+ for field in text_fields:
58
+ if field in json_data['data']['attributes'] and json_data['data']['attributes'][field]:
59
+ # print(json_data[field])
60
+ text_parts.append(str(json_data['data']['attributes'][field]))
61
+
62
+ return " ".join(text_parts) if text_parts else "No content available"
63
+
64
+ def rephrase_and_expand_query(query: str, llm: Any) -> str:
65
+
66
+ # Use LLM to rewrite and expand a query for better alignment with archive metadata.
67
+ prompt_template = PromptTemplate.from_template(
68
+ """
69
+ You are a professional librarian skilled at historical research.
70
+ Your task is to improve and expand the following search query to better match metadata in a historical archive.
71
+
72
+ - First, rewrite the query to improve clarity and fit how librarians would search.
73
+ - Second, expand the query by adding related terms (synonyms, related concepts, historical terminology, etc.).
74
+
75
+ Return your output strictly in this format (no extra explanation):
76
+ <IMPROVED_QUERY>your improved query here</IMPROVED_QUERY>
77
+ <EXPANDED_QUERY>your expanded query here</EXPANDED_QUERY>
78
+
79
+ Original Query: {query}
80
+ """
81
+ )
82
+
83
+ prompt = prompt_template.invoke({"query": query})
84
+ response = llm.invoke(prompt)
85
+
86
+ # Extract just the improved and expanded queries
87
+ improved_match = re.search(r"<IMPROVED_QUERY>(.*?)</IMPROVED_QUERY>", response.content, re.DOTALL)
88
+ expanded_match = re.search(r"<EXPANDED_QUERY>(.*?)</EXPANDED_QUERY>", response.content, re.DOTALL)
89
+
90
+ improved_query = improved_match.group(1).strip() if improved_match else query
91
+ expanded_query = expanded_match.group(1).strip() if expanded_match else ""
92
+
93
+ final_query = f"{improved_query} {expanded_query}".strip()
94
+
95
+ logging.info(f"Original Query: {query}")
96
+ logging.info(f"Improved Query: {improved_query}")
97
+ logging.info(f"Expanded Query: {expanded_query}")
98
+ logging.info(f"Final Query for Retrieval: {final_query}")
99
+
100
+ return final_query
101
+
102
+
103
+
104
+ weights = {
105
+ "title_info_primary_tsi": 1.5, # Titles should be prioritized
106
+ "name_role_tsim": 1.4, # Author/role should be highly weighted
107
+ "date_tsim": 1.3, # Date should be considered
108
+ "abstract_tsi": 1.0, # Abstracts are important but less so
109
+ "note_tsim": 0.8,
110
+ "subject_geographic_sim": 0.5,
111
+ "genre_basic_ssim": 0.5,
112
+ "genre_specific_ssim": 0.5,
113
+ }
114
+
115
+ def get_metadata(document_ids: List[str]) -> Dict[str, Dict]:
116
+ """ Fetch metadata from either PostgreSQL or the Commonwealth API, based on config """
117
+
118
+ if USE_DB_FOR_METADATA:
119
+ return get_metadata_from_db(document_ids)
120
+ else:
121
+ return get_metadata_from_api(document_ids)
122
+
123
+ def get_metadata_from_db(document_ids: List[str]) -> Dict[str, Dict]:
124
+ """ Fetch metadata from PostgreSQL """
125
+ conn = psycopg2.connect(
126
+ host="127.0.0.1",
127
+ port="5435",
128
+ dbname="bpl_metadata",
129
+ user="postgres",
130
+ password="MNOF.MzLDjcgzAXu" # Replace with real one or load with dotenv
131
+ )
132
+ cur = conn.cursor()
133
+
134
+ sql_query = """
135
+ SELECT id, title, abstract, subjects, institution, metadata_url, image_url
136
+ FROM metadata
137
+ WHERE id = ANY(%s);
138
+ """
139
+ cur.execute(sql_query, (document_ids,))
140
+ results = cur.fetchall()
141
+ cur.close()
142
+ conn.close()
143
+
144
+ # Convert results to a dictionary
145
+ return {
146
+ row[0]: {
147
+ "title": row[1],
148
+ "abstract": row[2],
149
+ "subjects": row[3],
150
+ "institution": row[4],
151
+ "metadata_url": row[5],
152
+ "image_url": row[6],
153
+ }
154
+ for row in results
155
+ }
156
+
157
+ def get_metadata_from_api(document_ids: List[str]) -> Dict[str, Dict]:
158
+ """ Fetch metadata from the Commonwealth API """
159
+ metadata_dict = {}
160
+ for doc_id in document_ids:
161
+ url = f"https://www.digitalcommonwealth.org/search/{doc_id}.json"
162
+ json_data = safe_get_json(url)
163
+ if json_data:
164
+ metadata_dict[doc_id] = extract_text_from_json(json_data)
165
+ return metadata_dict
166
+
167
+
168
+
169
+ """
170
+ def rerank(documents: List[Document], query: str) -> List[Document]:
171
+ \"\"\"Ingest more metadata. Rerank documents using BM25\"\"\"
172
+ start = time.time()
173
+ if not documents:
174
+ return []
175
+
176
+ full_docs = []
177
+ seen_sources = set()
178
+ meta_start = time.time()
179
+ for doc in documents:
180
+ source = doc.metadata.get('source')
181
+ if not source or source in seen_sources:
182
+ continue # Skip duplicate sources
183
+ seen_sources.add(source)
184
+
185
+ url = f"https://www.digitalcommonwealth.org/search/{source}"
186
+ json_data = safe_get_json(f"{url}.json")
187
+
188
+ if json_data:
189
+ text_content = extract_text_from_json(json_data)
190
+ if text_content: # Only add documents with actual content
191
+ full_docs.append(Document(page_content=text_content, metadata={"source": source, "field": doc.metadata.get("field", ""), "URL": url}))
192
+
193
+ logging.info(f"Took {time.time()-meta_start} seconds to retrieve all metadata")
194
+ if not full_docs:
195
+ return []
196
+
197
+ # Create BM25 retriever with the processed documents
198
+ bm25 = BM25Retriever.from_documents(full_docs, k=min(10, len(full_docs)))
199
+ bm25_ranked_docs = bm25.invoke(query)
200
+
201
+ ranked_docs = []
202
+ for doc in bm25_ranked_docs:
203
+ bm25_score = 1.0
204
+
205
+ # Compute metadata multiplier
206
+ metadata_multiplier = 1.0
207
+ for field, weight in weights.items():
208
+ if field in doc.metadata and doc.metadata[field]:
209
+ metadata_multiplier += weight
210
+
211
+ # Compute final score: BM25 weight * Metadata multiplier
212
+ final_score = bm25_score * metadata_multiplier
213
+ ranked_docs.append((doc, final_score))
214
+
215
+ # Sort by final score
216
+ ranked_docs.sort(key=lambda x: x[1], reverse=True)
217
+
218
+ logging.info(f"Finished reranking: {time.time()-start}")
219
+ return [doc for doc, _ in ranked_docs]
220
+ """
221
+
222
+ '''
223
+ def rerank(documents: List[Document], query: str) -> List[Document]:
224
+ """Retrieve metadata from the database and rerank using BM25"""
225
+ start = time.time()
226
+ if not documents:
227
+ return []
228
+
229
+ document_ids = [doc.metadata.get('source') for doc in documents if doc.metadata.get('source')]
230
+
231
+ # Fetch metadata from PostgreSQL
232
+ metadata_dict = get_metadata_from_db(document_ids)
233
+
234
+ full_docs = []
235
+ for doc in documents:
236
+ doc_id = doc.metadata.get('source')
237
+ metadata = metadata_dict.get(doc_id, {})
238
+
239
+ if metadata:
240
+ text_content = " ".join([
241
+ metadata.get("title", ""),
242
+ metadata.get("abstract", ""),
243
+ " ".join(metadata.get("subjects", [])),
244
+ metadata.get("institution", "")
245
+ ]).strip()
246
+
247
+
248
+ if text_content:
249
+ full_docs.append(Document(page_content=text_content, metadata={
250
+ "source": doc_id,
251
+ "URL": metadata.get("metadata_url", ""),
252
+ "image_url": metadata.get("image_url", "")
253
+ }))
254
+
255
+ logging.info(f"Took {time.time()-start} seconds to retrieve all metadata from PostgreSQL")
256
+
257
+ if not full_docs:
258
+ return []
259
+
260
+ # Rerank using BM25
261
+ bm25 = BM25Retriever.from_documents(full_docs, k=min(10, len(full_docs)))
262
+ bm25_ranked_docs = bm25.invoke(query)
263
+
264
+ ranked_docs = []
265
+ for doc in bm25_ranked_docs:
266
+ bm25_score = 1.0
267
+
268
+ # Compute metadata multiplier
269
+ metadata_multiplier = 1.0
270
+ for field, weight in weights.items():
271
+ if field in doc.metadata and doc.metadata[field]:
272
+ metadata_multiplier += weight
273
+
274
+ # Compute final score: BM25 weight * Metadata multiplier
275
+ final_score = bm25_score * metadata_multiplier
276
+ ranked_docs.append((doc, final_score))
277
+
278
+ # Sort by final score
279
+ ranked_docs.sort(key=lambda x: x[1], reverse=True)
280
+
281
+ logging.info(f"Finished reranking: {time.time()-start}")
282
+ return [doc for doc, _ in ranked_docs]
283
+ '''
284
+
285
+ def rerank(documents: List[Document], query: str) -> List[Document]:
286
+ """Rerank using BM25 and enhance scores using document metadata."""
287
+ start = time.time()
288
+
289
+ if not documents:
290
+ return []
291
+
292
+ # Group document chunks by source_id
293
+ grouped = defaultdict(list)
294
+ for doc in documents:
295
+ source_id = doc.metadata.get("source")
296
+ if source_id:
297
+ grouped[source_id].append(doc)
298
+
299
+ full_docs = []
300
+ for source_id, chunks in grouped.items():
301
+ combined_text = " ".join([chunk.page_content for chunk in chunks if chunk.page_content])
302
+ representative_metadata = chunks[0].metadata or {}
303
+
304
+ #logging.debug(f"Metadata for doc {source_id}: {representative_metadata}")
305
+
306
+ if combined_text.strip():
307
+ full_docs.append(Document(
308
+ page_content=combined_text.strip(),
309
+ metadata={
310
+ "source": source_id,
311
+ "URL": representative_metadata.get("metadata_url", ""),
312
+ "image_url": representative_metadata.get("image_url", ""),
313
+ **representative_metadata # preserve all original fields
314
+ }
315
+ ))
316
+
317
+ logging.info(f"Built {len(full_docs)} documents for reranking in {time.time() - start:.2f} seconds.")
318
+
319
+ if not full_docs:
320
+ return []
321
+
322
+ # BM25 reranking
323
+ bm25 = BM25Retriever.from_documents(full_docs, k=min(10, len(full_docs)))
324
+ bm25_ranked_docs = bm25.invoke(query)
325
+
326
+ # Score enhancement using metadata weights
327
+ ranked_docs = []
328
+ for doc in bm25_ranked_docs:
329
+ bm25_score = 1.0 # BM25 returns sorted, so base score is 1
330
+ metadata_multiplier = 1.0
331
+ for field, weight in weights.items():
332
+ if field in doc.metadata and doc.metadata[field]:
333
+ metadata_multiplier += weight
334
+ final_score = bm25_score * metadata_multiplier
335
+ ranked_docs.append((doc, final_score))
336
+
337
+ # Sort by enhanced score
338
+ ranked_docs.sort(key=lambda x: x[1], reverse=True)
339
+ logging.info(f"Finished reranking in {time.time() - start:.2f} seconds")
340
+
341
+ return [doc for doc, _ in ranked_docs]
342
+
343
+
344
+
345
+ def parse_xml_and_query(query:str,xml_string:str) -> str:
346
+ """parse xml and return rephrased query"""
347
+ if not xml_string:
348
+ return "No response generated."
349
+
350
+ pattern = r"<(\w+)>(.*?)</\1>"
351
+ matches = re.findall(pattern, xml_string, re.DOTALL)
352
+ parsed_response = dict(matches)
353
+ if parsed_response.get('VALID') == 'NO':
354
+ return query
355
+ return parsed_response.get('STATEMENT', query)
356
+
357
+
358
+ def parse_xml_and_check(xml_string: str) -> str:
359
+ """Parse XML-style tags and handle validation."""
360
+ if not xml_string:
361
+ return "No response generated."
362
+
363
+ pattern = r"<(\w+)>(.*?)</\1>"
364
+ matches = re.findall(pattern, xml_string, re.DOTALL)
365
+ parsed_response = dict(matches)
366
+
367
+ if parsed_response.get('VALID') == 'NO':
368
+ return "Sorry, I was unable to find any documents for your query.\n\n Here are some documents I found that might be relevant."
369
+
370
+ return parsed_response.get('RESPONSE', "No response found in the output")
371
+
372
+ def RAG(llm: Any, query: str,vectorstore:PineconeVectorStore, top: int = 10, k: int = 100) -> Tuple[str, List[Document]]:
373
+ """Main RAG function with improved error handling and validation."""
374
+ start = time.time()
375
+ try:
376
+
377
+ # Query alignment is commented our, however I have decided to leave it in for potential future use.
378
+
379
+ # Retrieve initial documents using rephrased query -- not working as intended currently, maybe would be better for data with more words.
380
+ # query_template = PromptTemplate.from_template(
381
+ # """
382
+ # Your job is to think about a query and then generate a statement that only includes information from the query that would answer the query.
383
+ # You will be provided with a query in <QUERY></QUERY> tags.
384
+ # Then you will think about what kind of information the query is looking for between <REASONING></REASONING> tags.
385
+ # Then, based on the reasoning, you will generate a sample response to the query that only includes information from the query between <STATEMENT></STATEMENT> tags.
386
+ # Afterwards, you will determine and reason about whether or not the statement you generated only includes information from the original query and would answer the query between <DETERMINATION></DETERMINATION> tags.
387
+ # Finally, you will return a YES, or NO response between <VALID></VALID> tags based on whether or not you determined the statment to be valid.
388
+ # Let me provide you with an exmaple:
389
+
390
+ # <QUERY>I would really like to learn more about Bermudan geography<QUERY>
391
+
392
+ # <REASONING>This query is interested in geograph as it relates to Bermuda. Some things they might be interested in are Bermudan climate, towns, cities, and geography</REASONING>
393
+
394
+ # <STATEMENT>Bermuda's Climate is [blank]. Some of Bermuda's cities and towns are [blank]. Other points of interested about Bermuda's geography are [blank].</STATEMENT>
395
+
396
+ # <DETERMINATION>The query originally only mentions bermuda and geography. The answers do not provide any false information, instead replacing meaningful responses with a placeholder [blank]. If it had hallucinated, it would not be valid. Because the statements do not hallucinate anything, this is a valid statement.</DETERMINATION>
397
+
398
+ # <VALID>YES</VALID>
399
+
400
+ # Now it's your turn! Remember not to hallucinate:
401
+
402
+ # <QUERY>{query}</QUERY>
403
+ # """
404
+ # )
405
+ # query_prompt = query_template.invoke({"query":query})
406
+ # query_response = llm.invoke(query_prompt)
407
+ # new_query = parse_xml_and_query(query=query,xml_string=query_response.content)
408
+
409
+ #logging.info(f"\n---\nQUERY: {query}")
410
+
411
+ #new query rephrasing
412
+ #query = rephrase_and_expand_query(query, llm)
413
+ #logging.info(f"\n---\nRephrased QUERY: {query}")
414
+
415
+ retrieved, _ = retrieve(query=query, vectorstore=vectorstore, k=k)
416
+ if not retrieved:
417
+ return "No documents found for your query.", []
418
+
419
+ # Rerank documents
420
+ reranked = rerank(documents=retrieved, query=query)
421
+ logging.info(f"RERANKED LENGTH: {len(reranked)}")
422
+ if not reranked:
423
+ return "Unable to process the retrieved documents.", []
424
+
425
+ # Prepare context from reranked documents
426
+ context = "\n\n".join(doc.page_content for doc in reranked[:top] if doc.page_content)
427
+ if not context.strip():
428
+ return "No relevant content found in the documents.", []
429
+ # change for the sake of another commit
430
+ # Prepare prompt
431
+ answer_template = PromptTemplate.from_template(
432
+ """Pretend you are a professional librarian. Please Summarize The Following Context as though you had retrieved it for a patron:
433
+ Some of the retrieved results may include image descriptions, captions, or references to photos, rather than the images themselves.
434
+ Assume that content describing or captioning an image, or mentioning a place/person clearly, is valid and relevant — even if the actual image isn't embedded.
435
+ Context:{context}
436
+ Make sure to answer in the following format
437
+ First, reason about the answer between <REASONING></REASONING> headers,
438
+ based on the context determine if there is sufficient material for answering the exact question,
439
+ return either <VALID>YES</VALID> or <VALID>NO</VALID>
440
+ then return a response between <RESPONSE></RESPONSE> headers:
441
+ Here is an example
442
+ <EXAMPLE>
443
+ <QUERY>Are pineapples a good fuel for cars?</QUERY>
444
+ <CONTEXT>Cars use gasoline for fuel. Some cars use electricity for fuel.Tesla stock has increased by 10 percent over the last quarter.</CONTEXT>
445
+ <REASONING>Based on the context pineapples have not been explored as a fuel for cars. The context discusses gasoline, electricity, and tesla stock, therefore it is not relevant to the query about pineapples for fuel</REASONING>
446
+ <VALID>NO</VALID>
447
+ <RESPONSE>Pineapples are not a good fuel for cars, however with further research they might be</RESPONSE>
448
+ </EXAMPLE>
449
+ Now it's your turn
450
+ <QUERY>
451
+ {query}
452
+ </QUERY>"""
453
+ )
454
+
455
+ # Generate response
456
+ ans_prompt = answer_template.invoke({"context": context, "query": query})
457
+ response = llm.invoke(ans_prompt)
458
+
459
+ # Parse and return response
460
+ logging.debug(f"RAW LLM RESPONSE:\n{response.content}")
461
+ parsed = parse_xml_and_check(response.content)
462
+ logging.debug(f"PARSED FINAL RESPONSE: {parsed}")
463
+ #logging.info(f"RESPONSE: {parsed}\nRETRIEVED: {reranked}")
464
+ logging.info(f"RAG Finished: {time.time()-start}\n---\n")
465
+ return parsed, reranked
466
+
467
+ except Exception as e:
468
+ logging.error(f"Error in RAG function: {str(e)}")
469
+ return f"An error occurred while processing your query: {str(e)}", []
streamlit_app.py ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import os
3
+ from typing import List, Tuple, Optional
4
+ from pinecone import Pinecone
5
+ from langchain_pinecone import PineconeVectorStore
6
+ from langchain_huggingface import HuggingFaceEmbeddings
7
+ from langchain_openai import ChatOpenAI
8
+ from langchain_core.prompts import PromptTemplate
9
+ from dotenv import load_dotenv
10
+ from RAG import RAG
11
+ import logging
12
+ from image_scraper import DigitalCommonwealthScraper
13
+ import shutil
14
+
15
+ # Configure logging
16
+ logging.basicConfig(level=logging.INFO)
17
+ logger = logging.getLogger(__name__)
18
+
19
+ # Page configuration
20
+ st.set_page_config(
21
+ page_title="Boston Public Library Chatbot",
22
+ page_icon="🤖",
23
+ layout="wide"
24
+ )
25
+
26
+ def initialize_models() -> Tuple[Optional[ChatOpenAI], HuggingFaceEmbeddings]:
27
+ """Initialize the language model and embeddings."""
28
+ try:
29
+ load_dotenv()
30
+
31
+ if "llm" not in st.session_state:
32
+ # Initialize OpenAI model
33
+ st.session_state.llm = ChatOpenAI(
34
+ model="gpt-4o-mini", # Changed from gpt-4o-mini which appears to be a typo
35
+ temperature=0,
36
+ timeout=60, # Added reasonable timeout
37
+ max_retries=2
38
+ )
39
+
40
+ if "embeddings" not in st.session_state:
41
+ # Initialize embeddings
42
+ st.session_state.embeddings = HuggingFaceEmbeddings(
43
+ model_name="sentence-transformers/all-mpnet-base-v2"
44
+ #model_name="sentence-transformers/all-MiniLM-L6-v2"
45
+ )
46
+
47
+ if "pinecone" not in st.session_state:
48
+ pinecone_api_key = os.getenv("PINECONE_API_KEY")
49
+ INDEX_NAME = 'bpl-test'
50
+ #initialize vectorstore
51
+ pc = Pinecone(api_key=pinecone_api_key)
52
+
53
+ index = pc.Index(INDEX_NAME)
54
+ st.session_state.pinecone = PineconeVectorStore(index=index, embedding=st.session_state.embeddings)
55
+
56
+ if "vectorstore" not in st.session_state:
57
+ #st.session_state.vectorstore = CloudSQLVectorStore(embedding=st.session_state.embeddings)
58
+ st.session_state.vectorstore = st.session_state.pinecone
59
+
60
+ except Exception as e:
61
+ logger.error(f"Error initializing models: {str(e)}")
62
+ st.error(f"Failed to initialize models: {str(e)}")
63
+ return None, None
64
+
65
+ def process_message(
66
+ query: str,
67
+ llm: ChatOpenAI,
68
+ vectorstore: PineconeVectorStore,
69
+
70
+ ) -> Tuple[str, List]:
71
+ """Process the user message using the RAG system."""
72
+ try:
73
+ response, sources = RAG(
74
+ query=query,
75
+ llm=llm,
76
+ vectorstore=vectorstore,
77
+ )
78
+ return response, sources
79
+ except Exception as e:
80
+ logger.error(f"Error in process_message: {str(e)}")
81
+ return f"Error processing message: {str(e)}", []
82
+
83
+ def display_sources(sources: List) -> None:
84
+ """Display sources with minimal output: content preview, source, URL, and image if available."""
85
+ if not sources:
86
+ st.info("No sources available for this response.")
87
+ return
88
+
89
+ st.subheader("Sources")
90
+ for doc in sources:
91
+ try:
92
+ source = doc.metadata.get("source", "Unknown Source")
93
+ title = doc.metadata.get("title_info_primary_tsi", "Unknown Title")
94
+
95
+ with st.expander(f"{title}"):
96
+ # Content preview
97
+ if hasattr(doc, 'page_content'):
98
+ st.markdown(f"**Content:** {doc.page_content[:100]} ...")
99
+
100
+ # Extract URL
101
+ doc_url = doc.metadata.get("URL", "").strip()
102
+ if not doc_url and source:
103
+ doc_url = f"https://www.digitalcommonwealth.org/search/{source}"
104
+
105
+ st.markdown(f"**Source ID:** {source}")
106
+ st.markdown(f"**URL:** {doc_url}")
107
+
108
+ # Try to show an image
109
+ scraper = DigitalCommonwealthScraper()
110
+ images = scraper.extract_images(doc_url)
111
+ images = images[:1]
112
+
113
+ if images:
114
+ output_dir = 'downloaded_images'
115
+ if os.path.exists(output_dir):
116
+ shutil.rmtree(output_dir)
117
+ downloaded_files = scraper.download_images(images)
118
+ st.image(downloaded_files, width=400, caption=[
119
+ img.get('alt', f'Image') for img in images
120
+ ])
121
+ except Exception as e:
122
+ logger.warning(f"[display_sources] Error displaying document: {e}")
123
+ st.error("Error displaying one of the sources.")
124
+
125
+
126
+ def main():
127
+ st.title("Digital Commonwealth RAG 🤖")
128
+
129
+ INDEX_NAME = 'bpl-rag'
130
+
131
+ # Initialize session state
132
+ if "messages" not in st.session_state:
133
+ st.session_state.messages = []
134
+
135
+ if "show_settings" not in st.session_state:
136
+ st.session_state.show_settings = False
137
+
138
+ if "num_sources" not in st.session_state:
139
+ st.session_state.num_sources = 10
140
+
141
+
142
+ initialize_models()
143
+
144
+ # 🔵 Settings button
145
+ open_settings = st.button("⚙️ Settings")
146
+
147
+ if open_settings:
148
+ st.session_state.show_settings = True
149
+
150
+ if st.session_state.show_settings:
151
+ with st.container():
152
+ st.markdown("---")
153
+ st.markdown("### ⚙️ Settings")
154
+
155
+ num_sources = st.number_input(
156
+ "Number of Sources to Display",
157
+ min_value=1,
158
+ max_value=100,
159
+ value=st.session_state.num_sources,
160
+ step=1,
161
+ )
162
+ st.session_state.num_sources = num_sources
163
+
164
+ close_settings = st.button("❌ Close Settings")
165
+ if close_settings:
166
+ st.session_state.show_settings = False
167
+ st.markdown("---")
168
+
169
+ # Show chat history
170
+ for message in st.session_state.messages:
171
+ with st.chat_message(message["role"]):
172
+ st.markdown(message["content"])
173
+
174
+ # ⬇️ CHAT INPUT BOX always stuck to bottom
175
+ user_input = st.chat_input("Type your question here...")
176
+
177
+ if user_input:
178
+ with st.chat_message("user"):
179
+ st.markdown(user_input)
180
+ st.session_state.messages.append({"role": "user", "content": user_input})
181
+
182
+ with st.chat_message("assistant"):
183
+ with st.spinner("Thinking... Please be patient..."):
184
+ response, sources = process_message(
185
+ query=user_input,
186
+ llm=st.session_state.llm,
187
+ vectorstore=st.session_state.vectorstore
188
+ )
189
+
190
+ if isinstance(response, str):
191
+ st.markdown(response)
192
+ st.session_state.messages.append({
193
+ "role": "assistant",
194
+ "content": response
195
+ })
196
+
197
+ display_sources(sources[:int(st.session_state.num_sources)])
198
+ else:
199
+ st.error("Received an invalid response format")
200
+
201
+ # Footer (optional, will be above chat input)
202
+ st.markdown("---")
203
+ st.markdown(
204
+ "Built with Langchain + Streamlit + Pinecone",
205
+ help="Natural Language Querying for Digital Commonwealth"
206
+ )
207
+ st.markdown(
208
+ "The Digital Commonwealth site provides access to photographs, manuscripts, books, "
209
+ "audio recordings, and other materials of historical interest that have been digitized "
210
+ "and made available by members of Digital Commonwealth."
211
+ )
212
+
213
+ if __name__ == "__main__":
214
+ main()