import os
import gradio as gr
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import PyPDFLoader, WebBaseLoader
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.vectorstores import SKLearnVectorStore
from langchain_openai import ChatOpenAI
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_pinecone import PineconeVectorStore
from langchain.prompts import PromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from pydantic import BaseModel, Field
from typing import List, TypedDict, Optional
from langchain.schema import Document
from langgraph.graph import START, END, StateGraph
from dotenv import load_dotenv

load_dotenv()
# Seed sites for the financial knowledge base (also used by the local
# fallback sketch below if Pinecone is unreachable)
urls = [
    "https://www.investopedia.com/",
    "https://www.fool.com/",
    "https://www.morningstar.com/",
    "https://www.kiplinger.com/",
    "https://www.nerdwallet.com/"
]
# Initialize embeddings and the vector store
embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

# Connect to the hosted Pinecone index
try:
    pc = PineconeVectorStore(
        pinecone_api_key=os.environ.get('PINCE_CONE_LIGHT'),
        embedding=embedding_model,
        index_name='rag-rubic',
        namespace='vectors_lightmodel'
    )
    retriever = pc.as_retriever(search_kwargs={"k": 10})
except Exception as e:
    print(f"Pinecone connection error: {e}")
    retriever = None
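    # Best-effort local fallback (a sketch, not the deployed path): build a small
    # in-memory SKLearn index from the seed sites above so retrieval still works
    # without Pinecone. Assumes outbound network access; chunk sizes are guesses.
    # If this also fails, `retriever` stays None and every query routes to web search.
    try:
        raw_docs = WebBaseLoader(urls).load()
        splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
        chunks = splitter.split_documents(raw_docs)
        sk_store = SKLearnVectorStore.from_documents(chunks, embedding=embedding_model)
        retriever = sk_store.as_retriever(search_kwargs={"k": 10})
    except Exception as fallback_err:
        print(f"SKLearn fallback failed: {fallback_err}")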
# Initialize the LLM
llm = ChatOpenAI(
    model='gpt-4o-mini',
    api_key=os.environ.get('OPEN_AI_KEY'),
    temperature=0.2
)
# Schema for grading document relevance
class GradeDocuments(BaseModel):
    binary_score: str = Field(description="Documents are relevant to the question, 'yes' or 'no'")

structured_llm_grader = llm.with_structured_output(GradeDocuments)
# System and grading prompts
system = """You are a grader assessing the relevance of a retrieved document to a user question.
If the document contains keywords or semantic meaning related to the question, grade it as relevant.
Give a binary score, 'yes' or 'no', to indicate whether the document is relevant to the question."""
grade_prompt = ChatPromptTemplate.from_messages([
    ("system", system),
    ("human", "Retrieved document: \n\n {documents} \n\n User question: {question}")
])
retrieval_grader = grade_prompt | structured_llm_grader
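# Example (hypothetical) round-trip through the grader:
#   retrieval_grader.invoke({'question': 'What is an index fund?',
#                            'documents': 'An index fund passively tracks a market index...'})
#   -> GradeDocuments(binary_score='yes')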
# RAG prompt template
prompt = PromptTemplate(
    template='''
    You are a Registered Investment Advisor with expertise in Indian financial markets and client relations.
    Understand what the user is asking about their financial investments and answer their query using only the information in the documents below.
    If the documents do not contain the answer, say you don't know.
    Query: {question}
    Documents: {context}
    ''',
    input_variables=['question', 'context']
)
rag_chain = prompt | llm | StrOutputParser()
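# Example (hypothetical) call; `{context}` accepts a list of Documents, which
# PromptTemplate stringifies before the prompt reaches the model:
#   rag_chain.invoke({'context': docs, 'question': 'Is a SIP suitable for short-term goals?'})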
# Web search tool for supplementing the index with live data
# (TavilySearchResults takes `max_results`, not `k`; its wrapper reads
# TAVILY_API_KEY from the environment loaded by load_dotenv above)
web_search_tool = TavilySearchResults(max_results=10)
# Graph state shared across nodes
class GraphState(TypedDict):
    question: str
    generation: Optional[str]
    need_web_search: Optional[str]  # 'yes'/'no'; named so it doesn't clash with the 'web_search' node
    documents: List
def retrieve_db(state):
    """Retrieve candidate documents for the query from the vector store."""
    question = state['question']
    if retriever:
        try:
            results = retriever.invoke(question)
            return {'documents': results, 'question': question}
        except Exception as e:
            print(f"Retriever error: {e}")
    # No retriever (or it failed): return no documents and flag a web search
    return {'documents': [], 'question': question, 'need_web_search': 'yes'}
def grade_docs(state):
    """Grade each retrieved document for relevance to the question.
    Relevant docs are kept; any irrelevant doc triggers a web search."""
    question = state['question']
    docs = state['documents']
    filtered_docs = []
    web = "no"
    for doc in docs:
        # Grade one document at a time against the question
        score = retrieval_grader.invoke({'question': question, 'documents': doc.page_content})
        if score.binary_score.lower() == 'yes':
            filtered_docs.append(doc)
        else:
            web = 'yes'
    return {"documents": filtered_docs, "question": question, "need_web_search": web}
def decide(state):
    """Route to web search or straight to generation."""
    web = state.get('need_web_search', 'no')
    if web == 'yes':
        return 'web_search'
    return 'generate'
def web_search(state):
    """Perform a web search and keep both content and source URLs."""
    question = state['question']
    documents = state["documents"]
    results = web_search_tool.invoke({"query": question})
    if not results:
        # Nothing came back; keep whatever graded docs we already have
        return {"documents": documents, "question": question}
    # Each Tavily result is a dict with 'content' and 'url'; wrap the content in a
    # Document and keep the URL in metadata so generate() can cite it
    for res in results:
        documents.append(Document(page_content=res["content"], metadata={"source": res["url"]}))
    return {"documents": documents, "question": question}
def generate(state):
    """Generate the final answer from the accumulated documents."""
    documents = state['documents']
    question = state['question']
    response = rag_chain.invoke({'context': documents, 'question': question})
    # Collect unique source URLs, preserving order
    sources = list(dict.fromkeys(
        doc.metadata["source"] for doc in documents if "source" in doc.metadata
    ))
    # Append citations when we have any
    formatted_response = response + "\n\nSources:\n" + "\n".join(sources) if sources else response
    return {
        'documents': documents,
        'question': question,
        'generation': formatted_response
    }
# Build the workflow graph
workflow = StateGraph(GraphState)
workflow.add_node("retrieve", retrieve_db)
workflow.add_node("grader", grade_docs)
workflow.add_node("web_search", web_search)  # node name is distinct from the state key
workflow.add_node("generate", generate)
workflow.add_edge(START, "retrieve")
workflow.add_edge("retrieve", "grader")
workflow.add_conditional_edges(
    "grader",
    decide,
    {
        'web_search': 'web_search',
        'generate': 'generate'
    },
)
workflow.add_edge("web_search", "generate")
workflow.add_edge("generate", END)

# Compile the graph
crag = workflow.compile()
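# Quick smoke test (commented out; assumes the API keys above are configured):
# print(crag.invoke({"question": "How do index funds differ from ETFs?"})["generation"])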
# Gradio callback with chat-history management
def process_query(user_input, history):
    history = history or []
    history.append((user_input, ""))
    try:
        result = crag.invoke({"question": user_input})
        if result and 'generation' in result:
            response = result['generation']
        else:
            response = "I couldn't find relevant information to answer your question."
    except Exception as e:
        print(f"Error in crag execution: {e}")
        response = "I encountered an error while processing your request. Please try again."
    # Fill in the assistant's side of the latest turn and clear the textbox
    history[-1] = (user_input, response)
    return history, ""
# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# 🤖 RAG-Powered Financial Advisor Chatbot")
    chatbot = gr.Chatbot(
        [],
        elem_id="chatbot",
        bubble_full_width=False,
        height=600,
        avatar_images=(None, "🤖")
    )
    with gr.Row():
        msg = gr.Textbox(
            placeholder="Ask me anything about Indian financial markets...",
            label="Your question:",
            scale=9
        )
        submit_btn = gr.Button("Send", scale=1)
    clear_btn = gr.Button("Clear Chat")
    # Wire up event handlers
    submit_btn.click(
        process_query,
        inputs=[msg, chatbot],
        outputs=[chatbot, msg]
    )
    msg.submit(
        process_query,
        inputs=[msg, chatbot],
        outputs=[chatbot, msg]
    )
    clear_btn.click(lambda: [], outputs=[chatbot])
if __name__ == "__main__":
    demo.launch()