DrishtiSharma committed on
Commit
eb06b06
verified · 1 Parent(s): 8bb92b7

Create rag_app.py

Files changed (1)
  1. rag_app.py +304 -0
rag_app.py ADDED
@@ -0,0 +1,304 @@
+ import os
+ import sys
+ import asyncio
+ import requests
+
+ from dotenv import load_dotenv
+ from bs4 import BeautifulSoup
+ from langchain_groq import ChatGroq
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain.embeddings import HuggingFaceBgeEmbeddings
+ from langchain_community.vectorstores import Chroma
+ from langchain_community.document_loaders import TextLoader
+ from langchain.chains import ConversationalRetrievalChain
+ from langchain.schema import Document
+ from scrapegraphai.graphs import SmartScraperGraph
+ from crawl4ai import AsyncWebCrawler, CacheMode, CrawlerRunConfig
+
+ import chromadb
+ from chromadb.config import Settings
+
+ # Shared persistent Chroma client and settings
+ chroma_setting = Settings(anonymized_telemetry=False)
+ persist_directory = "chroma_db"
+ collection_metadata = {"hnsw:space": "cosine"}
+ client = chromadb.PersistentClient(path=persist_directory, settings=chroma_setting)
+
+ # Set the Windows event loop policy so subprocess-based crawling works
+ if sys.platform == "win32":
+     asyncio.set_event_loop_policy(asyncio.WindowsProactorEventLoopPolicy())
+
+ # Apply nest_asyncio to allow nested event loops
+ import nest_asyncio
+ nest_asyncio.apply()
+
+ # Load environment variables; confirm the Groq key is set without printing the secret
+ load_dotenv()
+ print("GROQ_API_KEY set:", bool(os.getenv("GROQ_API_KEY")))
+
+ class WebRAG:
+     def __init__(self):
+         # Groq LLM used by the ScrapeGraphAI scraper
+         self.llm = ChatGroq(
+             api_key=os.getenv("GROQ_API_KEY"),
+             model_name="mixtral-8x7b-32768"
+         )
+         # Groq LLM used to answer questions over the retrieved chunks
+         self.response_llm = ChatGroq(
+             api_key=os.getenv("GROQ_API_KEY"),
+             model_name="DeepSeek-R1-Distill-Llama-70B",
+             temperature=0.6,
+             max_tokens=2048,
+         )
+
+         # Initialize embeddings
+         model_kwargs = {"device": "cpu"}
+         encode_kwargs = {"normalize_embeddings": True}
+         self.embeddings = HuggingFaceBgeEmbeddings(
+             model_name="BAAI/bge-base-en-v1.5",
+             model_kwargs=model_kwargs,
+             encode_kwargs=encode_kwargs
+         )
+
+         # Initialize text splitter
+         self.text_splitter = RecursiveCharacterTextSplitter(
+             chunk_size=1000,
+             chunk_overlap=200
+         )
+
+         # Persistent Chroma vector store shared across runs
+         self.vector_store = Chroma(
+             embedding_function=self.embeddings,
+             client=client,
+             persist_directory=persist_directory,
+             client_settings=chroma_setting,
+         )
+         # Built in crawl_and_process(); ask_question() checks this before answering
+         self.qa_chain = None
+
+     def crawl_webpage_bs4(self, url):
+         """Crawl webpage using requests + BeautifulSoup."""
+         headers = {
+             'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
+         }
+         response = requests.get(url, headers=headers, timeout=30)
+         response.raise_for_status()
+
+         soup = BeautifulSoup(response.text, 'html.parser')
+
+         # Remove script and style elements
+         for script in soup(["script", "style"]):
+             script.decompose()
+
+         # Get text content from relevant tags
+         text_elements = soup.find_all(['p', 'h1', 'h2', 'h3', 'h4', 'h5', 'h6', 'li', 'div'])
+         content = ' '.join([elem.get_text(strip=True) for elem in text_elements])
+
+         # Collapse repeated whitespace
+         content = ' '.join(content.split())
+         return content
+
+     # Crawl4ai
+     async def crawl_webpage_crawl4ai_async(self, url):
+         """Crawl webpage using Crawl4ai asynchronously."""
+         try:
+             crawler_run_config = CrawlerRunConfig(cache_mode=CacheMode.BYPASS)
+             async with AsyncWebCrawler() as crawler:
+                 result = await crawler.arun(url=url, config=crawler_run_config)
+                 return result.markdown
+         except Exception as e:
+             raise Exception(f"Error in Crawl4ai async: {str(e)}")
+
+     def crawl_webpage_crawl4ai(self, url):
+         """Synchronous wrapper for crawl4ai."""
+         try:
+             loop = asyncio.get_event_loop()
+         except RuntimeError:
+             loop = asyncio.new_event_loop()
+             asyncio.set_event_loop(loop)
+
+         try:
+             return loop.run_until_complete(self.crawl_webpage_crawl4ai_async(url))
+         except Exception as e:
+             raise Exception(f"Error in Crawl4ai: {str(e)}")
+
+     def crawl_webpage_scrapegraph(self, url):
+         """Crawl webpage using ScrapeGraphAI, falling back to a local Ollama model."""
+         try:
+             # First try with Groq
+             graph_config = {
+                 "llm": {
+                     "api_key": os.getenv("GROQ_API_KEY"),
+                     "model": "groq/mixtral-8x7b-32768",
+                 },
+                 "verbose": True,
+                 "headless": True,
+                 "disable_async": True  # Use synchronous mode
+             }
+
+             scraper = SmartScraperGraph(
+                 prompt="Extract all the useful textual content from the webpage",
+                 source=url,
+                 config=graph_config
+             )
+
+             # Use synchronous run
+             result = scraper.run()
+             print("Groq scraping successful")
+             return str(result)
+
+         except Exception as e:
+             print(f"Groq scraping failed, falling back to Ollama: {str(e)}")
+             try:
+                 # Fallback to Ollama
+                 graph_config = {
+                     "llm": {
+                         "model": "ollama/deepseek-r1:8b",
+                         "temperature": 0,
+                         "max_tokens": 2048,
+                         "format": "json",
+                         "base_url": "http://localhost:11434",
+                     },
+                     "embeddings": {
+                         "model": "ollama/nomic-embed-text",
+                         "base_url": "http://localhost:11434",
+                     },
+                     "verbose": True,
+                     "disable_async": True  # Use synchronous mode
+                 }
+
+                 scraper = SmartScraperGraph(
+                     prompt="Extract all the useful textual content from the webpage",
+                     source=url,
+                     config=graph_config
+                 )
+
+                 result = scraper.run()
+                 print("Ollama scraping successful")
+                 return str(result)
+
+             except Exception as e2:
+                 raise Exception(f"Both Groq and Ollama scraping failed: {str(e2)}")
+
+     def crawl_and_process(self, url, scraping_method="beautifulsoup"):
+         """Crawl the URL and index its content in the vector store."""
+         try:
+             # Validate URL
+             if not url.startswith(('http://', 'https://')):
+                 raise ValueError("Invalid URL. Please include http:// or https://")
+
+             # Crawl the website using the selected method
+             if scraping_method == "beautifulsoup":
+                 content = self.crawl_webpage_bs4(url)
+             elif scraping_method == "crawl4ai":
+                 content = self.crawl_webpage_crawl4ai(url)
+             else:  # scrapegraph
+                 content = self.crawl_webpage_scrapegraph(url)
+
+             if not content:
+                 raise ValueError("No content found at the specified URL")
+
+             # Clean the content of any problematic characters
+             content = content.encode('utf-8', errors='ignore').decode('utf-8')
+
+             # Write the content to a temporary file with proper encoding
+             import tempfile
+             with tempfile.NamedTemporaryFile(mode='w', encoding='utf-8', delete=False, suffix='.txt') as temp_file:
+                 temp_file.write(content)
+                 temp_path = temp_file.name
+
+             try:
+                 # Load and split the document, tagging every chunk with its source URL
+                 docs = TextLoader(temp_path, encoding='utf-8').load()
+                 docs = [Document(page_content=doc.page_content, metadata={"source": url}) for doc in docs]
+                 chunks = self.text_splitter.split_documents(docs)
+                 print(f"Number of chunks: {len(chunks)}")
+                 print(f"First chunk source: {chunks[0].metadata['source']}")
+
+                 # Check whether this URL has already been indexed
+                 existing_urls = []
+                 if os.path.exists(persist_directory):
+                     print(f"Checking if URL {url} is already in the metadata")
+                     try:
+                         entities = self.vector_store.get(include=["metadatas"])
+                         print(f"Entities: {len(entities['metadatas'])}")
+                         for entry in entities['metadatas']:
+                             existing_urls.append(entry["source"])
+                     except Exception as e:
+                         print(f"Error checking existing URLs: {str(e)}")
+                 print(f"Existing URLs: {set(existing_urls)}")
+
+                 if url in set(existing_urls):
+                     # Reuse the vectors already stored for this URL
+                     print(f"URL {url} already exists in the vector store")
+                 else:
+                     # Add new documents to the vector store in batches
+                     MAX_BATCH_SIZE = 100
+                     for i in range(0, len(chunks), MAX_BATCH_SIZE):
+                         i_end = min(len(chunks), i + MAX_BATCH_SIZE)
+                         batch = chunks[i:i_end]
+                         self.vector_store.add_documents(batch)
+                         print(f"Vectors for batch {i} to {i_end} stored successfully...")
+
+                 # Create the QA chain, restricting retrieval to chunks from this URL
+                 self.qa_chain = ConversationalRetrievalChain.from_llm(
+                     llm=self.response_llm,
+                     retriever=self.vector_store.as_retriever(
+                         search_type="similarity",
+                         search_kwargs={"k": 5, "filter": {"source": url}}
+                     ),
+                     return_source_documents=True
+                 )
+
+             finally:
+                 # Clean up the temporary file
+                 try:
+                     os.unlink(temp_path)
+                 except OSError:
+                     pass
+
+         except Exception as e:
+             raise Exception(f"Error processing URL: {str(e)}")
+
+     def ask_question(self, question, chat_history=None):
+         """Ask a question about the processed content."""
+         try:
+             if not self.qa_chain:
+                 raise ValueError("Please crawl and process a URL first")
+
+             chat_history = chat_history or []
+             response = self.qa_chain.invoke({"question": question, "chat_history": chat_history[:4000]})
+             print(f"Response: {response}")
+             # DeepSeek-R1 returns its reasoning in a <think> block; keep only the final answer
+             final_answer = response["answer"].split("</think>\n\n")[-1]
+             return final_answer
+         except Exception as e:
+             raise Exception(f"Error generating response: {str(e)}")
+
+ def main():
+     # Initialize the RAG system
+     rag = WebRAG()
+
+     # Get the URL and scraping method from the user
+     url = input("Enter the URL to process: ")
+     scraping_method = input("Choose scraping method (beautifulsoup, scrapegraph, or crawl4ai): ").strip().lower()
+     print("Processing URL... This may take a moment.")
+     rag.crawl_and_process(url, scraping_method)
+
+     # Interactive Q&A loop
+     chat_history = []
+     while True:
+         question = input("\nEnter your question (or 'quit' to exit): ")
+         if question.lower() == 'quit':
+             break
+
+         answer = rag.ask_question(question, chat_history)
+         print("\nAnswer:", answer)
+         chat_history.append((question, answer))
+
+ if __name__ == "__main__":
+     main()