Pavan2k4 committed on
Commit 628936c · verified · 1 Parent(s): 8f4e106

Upload app.py

Files changed (1)
  1. app.py +325 -0
app.py ADDED
@@ -0,0 +1,325 @@
+ import os
+ import gradio as gr
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain_community.document_loaders import PyPDFLoader, WebBaseLoader
+ from langchain_community.tools.tavily_search import TavilySearchResults
+ from langchain_community.vectorstores import SKLearnVectorStore
+ from langchain_openai import ChatOpenAI
+ from langchain_huggingface import HuggingFaceEmbeddings
+ from langchain_pinecone import PineconeVectorStore
+ from langchain.prompts import PromptTemplate
+ from langchain_core.output_parsers import StrOutputParser
+ from langchain_core.prompts import ChatPromptTemplate
+ from langchain_core.runnables import RunnableLambda
+ from pydantic import BaseModel, Field
+ from typing import List, TypedDict, Optional
+ from langchain.schema import Document
+ from langgraph.graph import START, END, StateGraph
+ from dotenv import load_dotenv
+
+ load_dotenv()
+
+ # Financial-content sites scraped into the web retriever below
+ urls = [
+     "https://www.investopedia.com/",
+     "https://www.fool.com/",
+     "https://www.morningstar.com/",
+     "https://www.kiplinger.com/",
+     "https://www.nerdwallet.com/"
+ ]
+
+ # Initialize the embedding model
+ embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
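+ # Note: all-MiniLM-L6-v2 produces 384-dimensional embeddings, so the Pinecone
+ # index below must have been created with a matching dimension of 384.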
+
+ # Initialize Pinecone connection
+ try:
+     pc = PineconeVectorStore(
+         pinecone_api_key=os.environ.get('PINECONE_KEY'),
+         embedding=embedding_model,
+         index_name='rag-rubic',
+         namespace='vectors_lightmodel'
+     )
+     retriever = pc.as_retriever(search_kwargs={"k": 10})
+ except Exception as e:
+     print(f"Pinecone connection error: {e}")
+     # No DB retriever available; the graph falls back to web search downstream
+     retriever = None
+
+ # Initialize the LLM
+ llm = ChatOpenAI(
+     model='gpt-4o-mini',
+     api_key=os.environ.get('OPENAI_KEY'),
+     temperature=0.2
+ )
+
+ # Schema for grading document relevance
+ class GradeDocuments(BaseModel):
+     binary_score: str = Field(description="Documents are relevant to the question, 'yes' or 'no'")
+
+ structured_llm_grader = llm.with_structured_output(GradeDocuments)
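+ # with_structured_output makes the grader return a parsed GradeDocuments
+ # instance instead of free-form text, so binary_score can be read directly.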
+
+ # System and grading prompt
+ system = """You are a grader assessing the relevance of a retrieved document to a user question.
+ If the document contains keyword(s) or semantic meaning related to the question, grade it as relevant.
+ Give a binary score 'yes' or 'no' to indicate whether the document is relevant to the question."""
+
+ grade_prompt = ChatPromptTemplate.from_messages([
+     ("system", system),
+     ("human", "Retrieved document: \n\n {documents} \n\n User question: {question}")
+ ])
+
+ retrieval_grader = grade_prompt | structured_llm_grader
+
+ # RAG prompt template
+ prompt = PromptTemplate(
+     template='''
+     You are a Registered Investment Advisor with expertise in Indian financial markets and client relations.
+     Understand what the user is asking about their financial investments and answer their query using only the information in the documents.
+
+     Use the following documents to answer the question. If you do not know the answer, say you don't know.
+
+     Query: {question}
+     Documents: {context}
+     ''',
+     input_variables=['question', 'context']
+ )
+
+ rag_chain = prompt | llm | StrOutputParser()
+
+ # Web search tool (TavilySearchResults takes max_results rather than k;
+ # the API key is read from the TAVILY_API_KEY environment variable)
+ web_search_tool = TavilySearchResults(max_results=5)
+
+ # Load website data
+ try:
+     print("Loading web data...")
+     docs = []
+     for site in urls:
+         try:
+             docs.append(WebBaseLoader(site).load())
+         except Exception as e:
+             print(f"Error loading {site}: {e}")
+
+     # Flatten the per-site lists into a single list of documents
+     docs_list = [item for sublist in docs for item in sublist]
+
+     # Split documents into chunks
+     text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
+         chunk_size=1000,
+         chunk_overlap=100
+     )
+     doc_splits = text_splitter.split_documents(docs_list)
+
+     # Vector store over the web-scraped documents
+     vectorstore = SKLearnVectorStore.from_documents(
+         documents=doc_splits,
+         embedding=embedding_model
+     )
+     retriever_web = vectorstore.as_retriever(search_kwargs={"k": 5})
+     print(f"Loaded {len(doc_splits)} document chunks from web sources")
+ except Exception as e:
+     print(f"Error in web data processing: {e}")
+     # Fallback that returns no documents; RunnableLambda (unlike a bare lambda)
+     # exposes the .invoke() that web_search() calls below
+     retriever_web = RunnableLambda(lambda x: [])
+
+ # Graph state shared between the workflow nodes
+ class GraphState(TypedDict):
+     question: str
+     generation: Optional[str]
+     need_web_search: Optional[str]  # named so it doesn't clash with the 'web_search' node
+     documents: List
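+ # Each node returns a partial state dict; LangGraph merges the returned keys
+ # back into GraphState before the next node runs.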
+
+ def retrieve_db(state):
+     """Retrieve documents for the query from the Pinecone index."""
+     question = state['question']
+     if retriever:
+         try:
+             results = retriever.invoke(question)
+             return {'documents': results, 'question': question}
+         except Exception as e:
+             print(f"Retriever error: {e}")
+
+     # If the retriever fails or doesn't exist, return no documents and flag a web search
+     return {'documents': [], 'question': question, 'need_web_search': 'yes'}
+
+ def grade_docs(state):
+     """Grade the documents returned by retrieve_db for relevance."""
+     question = state['question']
+     docs = state['documents']
+
+     if not docs:
+         return {'documents': [], 'question': question, 'need_web_search': 'yes'}
+
+     filtered_docs = []
+     web_search_needed = "no"
+
+     try:
+         for doc in docs:
+             doc_content = doc.page_content if hasattr(doc, 'page_content') else str(doc)
+             score = retrieval_grader.invoke({'question': question, 'documents': doc_content})
+             if score.binary_score == 'yes':
+                 filtered_docs.append(doc)
+     except Exception as e:
+         print(f"Error in document grading: {e}")
+         web_search_needed = "yes"
+
+     # If no relevant documents survived grading, trigger a web search
+     if not filtered_docs:
+         web_search_needed = "yes"
+
+     return {
+         'documents': filtered_docs,
+         'question': question,
+         'need_web_search': web_search_needed
+     }
+
+ def decide(state):
+     """Route to web search or straight to generation, based on the grading result."""
+     if state.get('need_web_search', 'no') == 'yes':
+         return 'web_search'
+     return 'generate'
+
+ def web_search(state):
+     """Search the configured financial sites, falling back to Tavily if they return nothing."""
+     question = state['question']
+     documents = state.get('documents', [])
+
+     try:
+         # First try the website-specific retriever
+         docs = retriever_web.invoke(question)
+         if not docs:
+             # If no results, fall back to a Tavily search
+             search_results = web_search_tool.invoke(question)
+             data = "\n".join(result["content"] for result in search_results)
+             docs = [Document(page_content=data)]
+     except Exception as e:
+         print(f"Web search error: {e}")
+         # Fallback document if both searches fail
+         docs = [Document(page_content="Unable to retrieve additional information.")]
+
+     # Combine existing documents with the newly retrieved ones
+     all_docs = documents + docs
+
+     return {'documents': all_docs, 'question': question}
+
+ def generate(state):
+     """Generate a response from the retrieved documents."""
+     documents = state.get('documents', [])
+     question = state['question']
+
+     # Convert the documents into a single text context
+     if documents:
+         try:
+             context = "\n\n".join(
+                 doc.page_content if hasattr(doc, 'page_content') else str(doc)
+                 for doc in documents
+             )
+         except Exception as e:
+             print(f"Error processing documents: {e}")
+             context = "Error retrieving relevant information."
+     else:
+         context = "No relevant information found."
+
+     try:
+         response = rag_chain.invoke({'context': context, 'question': question})
+     except Exception as e:
+         print(f"Generation error: {e}")
+         response = "I apologize, but I encountered an error while generating a response. Please try asking your question again."
+
+     return {
+         'documents': documents,
+         'question': question,
+         'generation': response
+     }
+
+ # Build the workflow: retrieve -> grade -> (web search if needed) -> generate
+ workflow = StateGraph(GraphState)
+ workflow.add_node("retrieve", retrieve_db)
+ workflow.add_node("grader", grade_docs)
+ workflow.add_node("web_search", web_search)  # node name is distinct from the 'need_web_search' state key
+ workflow.add_node("generate", generate)
+
+ workflow.add_edge(START, "retrieve")
+ workflow.add_edge("retrieve", "grader")
+ workflow.add_conditional_edges(
+     "grader",
+     decide,
+     {
+         'web_search': 'web_search',
+         'generate': 'generate'
+     },
+ )
+ workflow.add_edge("web_search", "generate")
+ workflow.add_edge("generate", END)
+
+ # Compile the graph
+ crag = workflow.compile()
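+ # Illustrative invocation (mirrors what process_query does below):
+ #   result = crag.invoke({"question": "What is an index fund?"})
+ #   print(result["generation"])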
+
+ # Gradio callback with chat-history management
+ def process_query(user_input, history):
+     # Initialize history if it's None
+     if history is None:
+         history = []
+
+     # Add the user's message with an empty placeholder response
+     history.append((user_input, ""))
+
+     response = ""
+
+     try:
+         # Execute the graph
+         result = crag.invoke({"question": user_input})
+         if result and 'generation' in result:
+             response = result['generation']
+         else:
+             response = "I couldn't find relevant information to answer your question."
+     except Exception as e:
+         print(f"Error in crag execution: {e}")
+         response = "I encountered an error while processing your request. Please try again."
+
+     # Fill in the response for the latest turn
+     history[-1] = (user_input, response)
+
+     return history, ""
+
+ # Gradio interface
+ with gr.Blocks() as demo:
+     gr.Markdown("# 🤖 RAG-Powered Financial Advisor Chatbot")
+
+     chatbot = gr.Chatbot(
+         [],
+         elem_id="chatbot",
+         bubble_full_width=False,
+         height=600,
+         avatar_images=(None, "🤖")
+     )
+
+     with gr.Row():
+         msg = gr.Textbox(
+             placeholder="Ask me anything about Indian financial markets...",
+             label="Your question:",
+             scale=9
+         )
+         submit_btn = gr.Button("Send", scale=1)
+
+     clear_btn = gr.Button("Clear Chat")
+
+     # Event handlers: the Send button and pressing Enter both submit
+     submit_btn.click(
+         process_query,
+         inputs=[msg, chatbot],
+         outputs=[chatbot, msg]
+     )
+
+     msg.submit(
+         process_query,
+         inputs=[msg, chatbot],
+         outputs=[chatbot, msg]
+     )
+
+     clear_btn.click(lambda: [], outputs=[chatbot])
+
+
+ if __name__ == "__main__":
+     demo.launch()