Upload 7 files
Browse files- app.py +359 -0
- chains.py +262 -0
- entities.py +12 -0
- requirements.txt +213 -0
- sample_user_data.json +37 -0
- simple_rag.py +131 -0
- tools.py +125 -0
app.py
ADDED
@@ -0,0 +1,359 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import os
|
3 |
+
|
4 |
+
import pandas as pd
|
5 |
+
import langchain
|
6 |
+
os.environ['STREAMLIT_SERVER_ENABLE_STATIC_SERVING'] = 'false'
|
7 |
+
|
8 |
+
from simple_rag import app
|
9 |
+
|
10 |
+
import streamlit as st
|
11 |
+
import json
|
12 |
+
from io import StringIO
|
13 |
+
import tiktoken
|
14 |
+
import time
|
15 |
+
from langchain_community.document_loaders import PyMuPDFLoader
|
16 |
+
import traceback
|
17 |
+
import sqlite3 # Import SQLite
|
18 |
+
from dotenv import load_dotenv
|
19 |
+
load_dotenv()
|
20 |
+
|
21 |
+
import uuid # Import the UUID library
|
22 |
+
|
23 |
+
# Run configuration passed to ``app.invoke``.
# NOTE(review): the thread id is a fixed literal, so every browser session
# shares it — confirm whether it should be the per-session UUID instead.
config = {"configurable": {"thread_id": "sample"}}

# Token limits
GPT_LIMIT = 128000
GEMINI_LIMIT = 1000000
|
28 |
+
# Token counters
|
29 |
+
def count_tokens_gpt(text):
    """Return how many tokens *text* occupies under the ``gpt-4`` encoding.

    Args:
        text: The string to measure. ``None`` and the empty string count as
            zero tokens instead of raising.

    Returns:
        The token count as an ``int``.
    """
    # Nothing to encode; also keeps a missing model response from crashing.
    if not text:
        return 0
    enc = tiktoken.encoding_for_model("gpt-4")
    return len(enc.encode(text))
|
32 |
+
|
33 |
+
def count_tokens_gemini(text):
    """Approximate the Gemini token count of *text* by counting words.

    Args:
        text: The string to measure. ``None`` and the empty string count as
            zero tokens instead of raising.

    Returns:
        The number of whitespace-separated words in *text*.
    """
    if not text:
        return 0
    return len(text.split())  # Approximation
|
35 |
+
|
36 |
+
# Calculate tokens for the entire context window
|
37 |
+
def calculate_context_window_usage(json_data=None):
    """Count the tokens of the whole conversation plus optional JSON context.

    The conversation is rebuilt from ``st.session_state.chat_history``, whose
    entries are unpacked as ``(sender, message)`` pairs.

    Args:
        json_data: Optional JSON-serialisable object; when truthy its
            ``json.dumps`` form is appended to the conversation text.

    Returns:
        A ``(gpt_tokens, gemini_tokens)`` tuple for the combined text.
    """
    # Collect the pieces and join once instead of growing a string in a loop.
    parts = [
        f"{sender}: {message}\n\n"
        for sender, message in st.session_state.chat_history
    ]

    # Add JSON context if provided
    if json_data:
        parts.append(json.dumps(json_data))

    full_conversation = "".join(parts)

    gpt_tokens = count_tokens_gpt(full_conversation)
    gemini_tokens = count_tokens_gemini(full_conversation)

    return gpt_tokens, gemini_tokens
|
51 |
+
|
52 |
+
|
53 |
+
|
54 |
+
|
55 |
+
# Page configuration
|
56 |
+
st.set_page_config(page_title="📊 RAG Chat Assistant", layout="wide")
|
57 |
+
|
58 |
+
# --- Database setup ---
|
59 |
+
# DATABASE_PATH = "Data/chat_history.db" # Original database path
|
60 |
+
SESSION_DB_DIR = "Data/sessions" # Directory to store individual session DBs
|
61 |
+
|
62 |
+
def initialize_session_database(session_id, db_dir=None):
    """Create (if needed) the SQLite database for one chat session.

    Args:
        session_id: Identifier used as the database file name
            (``<session_id>.db``).
        db_dir: Directory that holds the session databases. Defaults to the
            module-level ``SESSION_DB_DIR``. It is created when missing.

    Returns:
        The path of the session database file.
    """
    if db_dir is None:
        db_dir = SESSION_DB_DIR
    # sqlite3 cannot create missing parent directories on its own.
    os.makedirs(db_dir, exist_ok=True)

    db_path = os.path.join(db_dir, f"{session_id}.db")
    conn = sqlite3.connect(db_path)
    try:
        conn.execute("""
            CREATE TABLE IF NOT EXISTS chat_history (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                sender TEXT,
                message TEXT,
                timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
            )
        """)
        conn.commit()
    finally:
        # Always release the file handle, even when the DDL fails.
        conn.close()
    return db_path
|
78 |
+
|
79 |
+
def save_message(db_path, sender, message):
    """Append one message to the ``chat_history`` table of a session database.

    Args:
        db_path: Path of the session's SQLite file.
        sender: Value stored in the ``sender`` column.
        message: Value stored in the ``message`` column.
    """
    conn = sqlite3.connect(db_path)
    try:
        # Parameter binding keeps quotes in the message from breaking the SQL.
        conn.execute(
            "INSERT INTO chat_history (sender, message) VALUES (?, ?)",
            (sender, message),
        )
        conn.commit()
    finally:
        # Close even if the insert raises, so the connection is not leaked.
        conn.close()
|
86 |
+
|
87 |
+
def clear_chat_history(db_path):
    """Delete every row from the ``chat_history`` table of a session database.

    Args:
        db_path: Path of the session's SQLite file.
    """
    conn = sqlite3.connect(db_path)
    try:
        conn.execute("DELETE FROM chat_history")
        conn.commit()
    finally:
        # Close even if the delete raises, so the connection is not leaked.
        conn.close()
|
94 |
+
|
95 |
+
# Initialize session DB directory.
# exist_ok avoids the check-then-create race when two script runs start together.
os.makedirs(SESSION_DB_DIR, exist_ok=True)
|
98 |
+
|
99 |
+
# --- Session state setup ---
|
100 |
+
# Seed each per-session key only when it is absent, so values set during an
# earlier script run are kept.
_session_defaults = {
    "chat_history": [
        ("assistant", "👋 Hello! I'm your RAG assistant. Please upload your JSON files and ask me a question about your portfolio.")
    ],
    "processing": False,
    "total_gpt_tokens": 0,  # Total accumulated
    "total_gemini_tokens": 0,  # Total accumulated
    "window_gpt_tokens": 0,  # Current context window
    "window_gemini_tokens": 0,  # Current context window
}
for _state_key, _state_default in _session_defaults.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default

# Generate a unique session ID if one doesn't exist, and create its database.
if "session_id" not in st.session_state:
    st.session_state.session_id = str(uuid.uuid4())
    st.session_state.session_db_path = initialize_session_database(
        st.session_state.session_id
    )
|
119 |
+
|
120 |
+
# --- Load chat history from the session database ---
|
121 |
+
def load_chat_history(db_path):
    """Return all stored messages of a session, oldest first.

    Args:
        db_path: Path of the session's SQLite file.

    Returns:
        A list of ``(sender, message)`` tuples ordered by timestamp; rows that
        share a timestamp are ordered by their auto-increment ``id``.
    """
    conn = sqlite3.connect(db_path)
    try:
        # CURRENT_TIMESTAMP has one-second resolution, so break ties on id.
        cursor = conn.execute(
            "SELECT sender, message FROM chat_history ORDER BY timestamp, id"
        )
        history = cursor.fetchall()
    finally:
        # Close even if the query raises, so the connection is not leaked.
        conn.close()
    return history
|
128 |
+
|
129 |
+
|
130 |
+
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
131 |
+
|
132 |
+
# Go one level up to reach RAG_rubik/
|
133 |
+
PROJECT_ROOT = os.path.dirname(BASE_DIR)
|
134 |
+
print(PROJECT_ROOT, BASE_DIR)
|
135 |
+
# --- Layout: Chat UI Left | Progress Bars Right ---
|
136 |
+
col_chat, col_progress = st.columns([3, 1])
|
137 |
+
|
138 |
+
# --- LEFT COLUMN: Chat UI ---
|
139 |
+
with col_chat:
|
140 |
+
st.title("💬 RAG Assistant")
|
141 |
+
|
142 |
+
with st.expander("📂 Upload Required JSON Files", expanded=True):
|
143 |
+
# user_data_file = st.file_uploader("Upload user_data.json", type="json", key="user_data")
|
144 |
+
# allocations_file = st.file_uploader("Upload allocations.json", type="json", key="allocations")
|
145 |
+
|
146 |
+
def _load_json_file(path, label, env_var):
    """Load one JSON input file, reporting any failure through ``st.error``.

    Args:
        path: File path to read; may be ``None`` when the env var is unset.
        label: Human-readable file name used in error messages.
        env_var: Name of the environment variable that supplies *path*.

    Returns:
        The parsed JSON value, or ``None`` when the path is missing, the file
        does not exist, or its content is not valid JSON.
    """
    # os.getenv returns None for an unset variable; open(None) would raise a
    # TypeError that neither handler below catches.
    if not path:
        st.error(f"Error: {label} path is not configured. Set the {env_var} environment variable.")
        return None
    try:
        with open(path, 'r', encoding='utf-8') as f:
            return json.load(f)
    except FileNotFoundError:
        st.error(f"Error: {label} not found at {path}")
    except json.JSONDecodeError:
        st.error(f"Error: Could not decode {label}. Please ensure it is valid JSON.")
    return None


user_data_path = os.getenv('USER_DATA_PATH')
allocations_path = os.getenv('ALLOCATIONS_PATH')

user_data = _load_json_file(user_data_path, "user_data.json", "USER_DATA_PATH")
allocations = _load_json_file(allocations_path, "allocations.json", "ALLOCATIONS_PATH")
|
168 |
+
|
169 |
+
# Render the loaded user profile as three side-by-side columns.
if user_data:
    # The key spellings ("sematic", "prefrences") are the ones this script
    # reads from the user-data JSON; they must stay as they are.
    sematic = user_data.get("sematic", {})
    demographic = sematic.get("demographic", {})
    financial = sematic.get("financial", {})
    episodic = user_data.get("episodic", {}).get("prefrences", [])
    goals = user_data.get("episodic", {}).get("goals", [])

    col1, col2, col3 = st.columns(3)

    with col1:
        st.markdown("### 🧾 **Demographic Info**")
        for key, value in demographic.items():
            st.markdown(f"- **{key.replace('_', ' ').title()}**: {value}")

    with col2:
        st.markdown("### 📊 **Financial Status**")
        for key, value in financial.items():
            st.markdown(f"- **{key.replace('_', ' ').title()}**: {value}")

    with col3:
        st.markdown("### ⚙️ **Preferences & Goals**")
        st.markdown("**User Preferences:**")
        # Reuse the list fetched above instead of looking it up a second time.
        for pref in episodic:
            st.markdown(f"- {pref.capitalize()}")
        st.markdown("**Goals:**")
        # Each goal is iterated as a mapping of field name to value.
        for goal in goals:
            for k, v in goal.items():
                st.markdown(f"- **{k.replace('_', ' ').title()}**: {v}")
|
196 |
+
|
197 |
+
|
198 |
+
|
199 |
+
|
200 |
+
# Keep the allocations in session state so later script runs can replace them.
if "allocations" not in st.session_state:
    st.session_state.allocations = allocations

if st.session_state.allocations:
    try:
        st.markdown("### 💼 Investment Allocations")

        # Flatten {asset_class: [entry, ...]} into one row per entry.
        records = [
            {
                "Asset Class": asset_class.replace("_", " ").title(),
                "Type": item.get("type", ""),
                "Label": item.get("label", ""),
                "Amount (₹)": item.get("amount", 0),
            }
            for asset_class, entries in st.session_state.allocations.items()
            for item in entries
        ]

        df = pd.DataFrame(records)
        st.dataframe(df)

    except Exception as e:
        st.error(f"Failed to parse allocations.json: {e}")
|
224 |
+
|
225 |
+
|
226 |
+
|
227 |
+
# Clear chat button: reset the transcript and token counters, wipe the
# session database, then redraw.
if st.button("Clear Chat"):
    st.session_state.chat_history = [
        ("assistant", "👋 Hello! I'm your RAG assistant. Please upload your JSON files and ask me a question about your portfolio.")
    ]
    for _counter in (
        "total_gpt_tokens",
        "total_gemini_tokens",
        "window_gpt_tokens",
        "window_gemini_tokens",
    ):
        st.session_state[_counter] = 0

    # Clear the chat history in the session database
    clear_chat_history(st.session_state.session_db_path)

    st.rerun()
|
242 |
+
|
243 |
+
st.markdown("---")
|
244 |
+
|
245 |
+
# Display chat history: any sender other than "user" is drawn as the assistant.
chat_container = st.container()
with chat_container:
    for sender, message in st.session_state.chat_history:
        role = "user" if sender == "user" else "assistant"
        st.chat_message(role).write(message)
|
253 |
+
|
254 |
+
# Show thinking animation if processing
if st.session_state.processing:
    thinking_placeholder = st.empty()
    with st.chat_message("assistant"):
        # Three passes over the dot frames, 0.3 s per frame.
        for dots in [".", "..", "..."] * 3:
            thinking_placeholder.markdown(f"Thinking{dots}")
            time.sleep(0.3)
|
262 |
+
|
263 |
+
# Input box at the bottom
user_input = st.chat_input("Type your question...")

# Accept a new question only while no earlier one is still being processed.
if not st.session_state.processing and user_input:
    # Mark the turn as in flight before anything else.
    st.session_state.processing = True

    # Record the question in the transcript and in the session database.
    st.session_state.chat_history.append(("user", user_input))
    save_message(st.session_state.session_db_path, "user", user_input)

    # Rerun so the message and the thinking indicator are drawn first.
    st.rerun()
|
276 |
+
|
277 |
+
# This part runs after the rerun if we're processing: answer the most recent
# user message, update the token counters, then clear the flag and redraw.
if st.session_state.processing:
    if not user_data or not allocations:
        st.session_state.chat_history.append(("assistant", "⚠️ Please upload both JSON files before asking questions."))
        st.session_state.processing = False
        st.rerun()
    else:
        try:
            # Combined JSON data (for token calculation)
            combined_json_data = {"user_data": user_data, "allocations": allocations}

            # Get the last user message
            last_user_message = next(
                (msg for sender, msg in reversed(st.session_state.chat_history) if sender == "user"),
                "",
            )

            # Add this user message to the accumulated totals
            st.session_state.total_gpt_tokens += count_tokens_gpt(last_user_message)
            st.session_state.total_gemini_tokens += count_tokens_gemini(last_user_message)

            # Calculate context window usage (conversation + JSON data)
            window_gpt, window_gemini = calculate_context_window_usage(combined_json_data)
            st.session_state.window_gpt_tokens = window_gpt
            st.session_state.window_gemini_tokens = window_gemini

            # Check token limits for context window
            if window_gpt > GPT_LIMIT or window_gemini > GEMINI_LIMIT:
                st.session_state.chat_history.append(("assistant", "⚠️ Your conversation has exceeded token limits. Please clear the chat to continue."))
                st.session_state.processing = False
                st.rerun()
            else:
                # --- Call LangGraph ---
                # NOTE(review): this sends the file-loaded `allocations`, not
                # st.session_state.allocations, so allocation updates returned
                # by an earlier turn are not passed back in — confirm intent.
                inputs = {
                    "query": last_user_message,
                    "user_data": user_data,
                    "allocations": allocations,
                    "chat_history": st.session_state.chat_history,
                }

                output = app.invoke(inputs, config=config)
                response = output.get('output')

                # The graph may omit 'output' or return a non-string; token
                # counting and the chat display both need a plain string.
                if response is None:
                    response = "⚠️ The assistant returned no output. Please try again."
                elif not isinstance(response, str):
                    response = str(response)

                # Check if the response contains allocation updates
                if "allocations" in output:
                    st.session_state.allocations = output["allocations"]

                # Add the response to the accumulated totals
                st.session_state.total_gpt_tokens += count_tokens_gpt(response)
                st.session_state.total_gemini_tokens += count_tokens_gemini(response)

                # Add to chat history
                st.session_state.chat_history.append(("assistant", response))

                # Update context window calculations after adding response
                window_gpt, window_gemini = calculate_context_window_usage(combined_json_data)
                st.session_state.window_gpt_tokens = window_gpt
                st.session_state.window_gemini_tokens = window_gemini

        except Exception as e:
            # Report the frame where the exception was raised. extract_tb on
            # the exception's own traceback gives that; extract_stack would
            # describe the call stack of this handler instead.
            frames = traceback.extract_tb(e.__traceback__)
            if frames:
                origin = frames[-1]
                error_message = f"❌ Error: {str(e)} in {origin.filename} at line {origin.lineno}, function: {origin.name}"
            else:
                error_message = f"❌ Error: {str(e)}"
            st.session_state.chat_history.append(("assistant", error_message))

        # Reset processing flag
        st.session_state.processing = False
        st.rerun()
|
chains.py
ADDED
@@ -0,0 +1,262 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""All prompts utilized by the RAG pipeline"""
|
2 |
+
from langchain_core.prompts import ChatPromptTemplate, PromptTemplate, SystemMessagePromptTemplate, HumanMessagePromptTemplate, MessagesPlaceholder
|
3 |
+
from langchain_core.output_parsers import StrOutputParser
|
4 |
+
from langchain_openai import ChatOpenAI
|
5 |
+
from pydantic import BaseModel, Field
|
6 |
+
from langchain_google_genai import ChatGoogleGenerativeAI
|
7 |
+
import os
|
8 |
+
from tools import json_to_table, goal_feasibility, save_data, rag_tool
|
9 |
+
from langchain.agents import initialize_agent, Tool
|
10 |
+
from langchain.agents import AgentType
|
11 |
+
from langgraph.prebuilt import create_react_agent
|
12 |
+
from langchain.tools import Tool
|
13 |
+
from dotenv import load_dotenv
|
14 |
+
load_dotenv()
|
15 |
+
|
16 |
+
|
17 |
+
gemini = ChatGoogleGenerativeAI(model = 'gemini-2.0-flash')
|
18 |
+
llm = ChatOpenAI(
|
19 |
+
model='gpt-4.1-nano',
|
20 |
+
api_key=os.environ.get('OPEN_AI_KEY'),
|
21 |
+
temperature=0.2
|
22 |
+
)
|
23 |
+
|
24 |
+
# Schema for grading documents.
# Documented with comments rather than a docstring so the schema generated
# from this model is unchanged.
class GradeDocuments(BaseModel):
    # Single field carrying the grader's verdict.
    binary_score: str = Field(
        description="Documents are relevant to the question, 'yes' or 'no'",
    )
|
27 |
+
|
28 |
+
structured_llm_grader = llm.with_structured_output(GradeDocuments)
|
29 |
+
system = """You are a grader assessing relevance of a retrieved document to a user question.
|
30 |
+
If the document contains keyword(s) or semantic meaning related to the question, grade it as relevant.
|
31 |
+
Give a binary score 'yes' or 'no' score to indicate whether the document is relevant to the question."""
|
32 |
+
|
33 |
+
grade_prompt = ChatPromptTemplate.from_messages([
|
34 |
+
("system", system),
|
35 |
+
("human", "Retrieved document: \n\n {data} \n\n User question: {query}")
|
36 |
+
])
|
37 |
+
|
38 |
+
retrieval_grader = grade_prompt | structured_llm_grader
|
39 |
+
|
40 |
+
|
41 |
+
prompt = PromptTemplate(
|
42 |
+
template='''
|
43 |
+
You are a SEBI-Registered Investment Advisor (RIA) specializing in Indian financial markets and client relationship management.
|
44 |
+
|
45 |
+
Your task is to understand and respond to the user's financial query using the following inputs:
|
46 |
+
- Query: {query}
|
47 |
+
- Documents: {data}
|
48 |
+
- User Profile: {user_data}
|
49 |
+
- Savings Allocations: {allocations}
|
50 |
+
|
51 |
+
|
52 |
+
|
53 |
+
Instructions:
|
54 |
+
1. Understand the User's Intent: Carefully interpret what the user is asking about their investments.
|
55 |
+
2. Analyze Allocations: Evaluate the savings allocation data to understand the user's current financial posture.
|
56 |
+
3. Personalized Response:
|
57 |
+
- If detailed user profile and allocation data are available, prioritize your response based on this data.
|
58 |
+
- If profile or allocation data is sparse, rely more heavily on the query context.
|
59 |
+
4. Use Supporting Documents: Extract relevant insights from the provided documents ({data}) to support your answer.
|
60 |
+
5. When Unsure: If the documents or data do not contain the necessary information, say "I don't know" rather than guessing.
|
61 |
+
|
62 |
+
Always aim to give a response that is:
|
63 |
+
- Data-informed
|
64 |
+
- Client-centric
|
65 |
+
- Aligned with Indian financial regulations and norms
|
66 |
+
|
67 |
+
|
68 |
+
''',
|
69 |
+
input_variables=['query', 'data', 'user_data', 'allocations']
|
70 |
+
)
|
71 |
+
|
72 |
+
rag_chain = prompt | gemini | StrOutputParser()
|
73 |
+
|
74 |
+
|
75 |
+
# Prompt
|
76 |
+
system_rewrite = """You a question re-writer that converts an input question to a better version that is optimized \n
|
77 |
+
for web search. Look at the input and try to reason about the underlying semantic intent / meaning."""
|
78 |
+
re_write_prompt = ChatPromptTemplate.from_messages(
|
79 |
+
[
|
80 |
+
("system", system_rewrite),
|
81 |
+
(
|
82 |
+
"human",
|
83 |
+
"Here is the initial question: \n\n {query} \n Formulate an improved question.",
|
84 |
+
),
|
85 |
+
]
|
86 |
+
)
|
87 |
+
|
88 |
+
question_rewriter = re_write_prompt | llm | StrOutputParser()
|
89 |
+
|
90 |
+
|
91 |
+
from pydantic import BaseModel, Field, RootModel
|
92 |
+
from typing import Dict
|
93 |
+
from langchain_core.output_parsers import JsonOutputParser
|
94 |
+
|
95 |
+
# Define the Pydantic model using RootModel
class CategoryProbabilities(RootModel):
    """Probabilities for different knowledge base categories."""

    # The model's root value is the mapping itself; there is no wrapper key.
    root: Dict[str, float] = Field(
        description="Dictionary mapping category names to probability scores",
    )
|
99 |
+
|
100 |
+
system_classifier = """You are a query classifier that determines the most relevant knowledge bases (KBs) for a given user query.
|
101 |
+
Analyze the semantic meaning and intent of the query and assign probability scores (between 0 and 1) to each KB.
|
102 |
+
|
103 |
+
Ensure the probabilities sum to 1 and output a JSON dictionary with category names as keys and probabilities as values.
|
104 |
+
"""
|
105 |
+
|
106 |
+
classification_prompt = ChatPromptTemplate.from_messages(
|
107 |
+
[
|
108 |
+
("system", system_classifier),
|
109 |
+
(
|
110 |
+
"human",
|
111 |
+
"Here is the user query: \n\n {query} \n\n Assign probability scores to each of the following KBs:\n"
|
112 |
+
"{categories}\n\nReturn a JSON object with category names as keys and probability scores as values."
|
113 |
+
),
|
114 |
+
]
|
115 |
+
)
|
116 |
+
|
117 |
+
# Create a JSON output parser
|
118 |
+
json_parser = JsonOutputParser(pydantic_object=CategoryProbabilities)
|
119 |
+
|
120 |
+
# Create the chain with the structured output parser
|
121 |
+
query_classifier = classification_prompt | llm | json_parser
|
122 |
+
|
123 |
+
|
124 |
+
#query_classifier = classification_prompt | llm | StrOutputParser()
|
125 |
+
|
126 |
+
"""
|
127 |
+
name: str
|
128 |
+
|
129 |
+
position: Dict[str, int]
|
130 |
+
riskiness: int
|
131 |
+
illiquidity: int
|
132 |
+
|
133 |
+
amount: float
|
134 |
+
currency: str = "inr"
|
135 |
+
percentage: float
|
136 |
+
explanation: Dict[str, str]
|
137 |
+
|
138 |
+
assets: List[AssetAllocation]
|
139 |
+
"""
|
140 |
+
#--------------------------------------------------------------------------------------
|
141 |
+
tools = [
|
142 |
+
{
|
143 |
+
"type": "function",
|
144 |
+
"function": {
|
145 |
+
"name": "json_to_table",
|
146 |
+
"description": "Convert JSON data to a markdown table. Use when user asks to visualise or tabulate structured data.",
|
147 |
+
"parameters": {
|
148 |
+
"type": "object",
|
149 |
+
"properties": {
|
150 |
+
"arguments": {
|
151 |
+
"type": "object",
|
152 |
+
"properties": {
|
153 |
+
"json_data": {
|
154 |
+
"type": "object",
|
155 |
+
"description": "The JSON data to convert to a table"
|
156 |
+
}
|
157 |
+
},
|
158 |
+
"required": ["json_data"]
|
159 |
+
}
|
160 |
+
},
|
161 |
+
"required": ["arguments"]
|
162 |
+
}
|
163 |
+
}
|
164 |
+
},
|
165 |
+
{
|
166 |
+
"type": "function",
|
167 |
+
"function": {
|
168 |
+
"name": "rag_tool",
|
169 |
+
"description": "Lets the agent use RAG system as a tool",
|
170 |
+
"parameters": {
|
171 |
+
"type": "object",
|
172 |
+
"properties": {
|
173 |
+
"arguments": {
|
174 |
+
"type": "object",
|
175 |
+
"properties": {
|
176 |
+
"query": {
|
177 |
+
"type": "string",
|
178 |
+
"description": "The query to search for in the RAG system"
|
179 |
+
}
|
180 |
+
},
|
181 |
+
"required": ["query"]
|
182 |
+
}
|
183 |
+
},
|
184 |
+
"required": ["arguments"]
|
185 |
+
}
|
186 |
+
}
|
187 |
+
}
|
188 |
+
]
|
189 |
+
|
190 |
+
|
191 |
+
# System prompt for the tool-enabled advisor chain. Placeholders: query,
# user_data, allocations, chat_history, retrieved_context. Doubled braces are
# literal braces after formatting. Instructions are numbered consecutively and
# the JSON example sits inside a closed code fence.
template = '''You are a SEBI-Registered Investment Advisor (RIA) specializing in Indian financial markets and client relationship management.

Your task is to understand and respond to the user's financial query using the following inputs:
- Query: {query}
- User Profile: {user_data}
- Savings Allocations: {allocations}
- Chat History: {chat_history}
- 🔎 Retrieved Context (optional): {retrieved_context}

Instructions:
1. **Understand the User's Intent**: Carefully interpret what the user is asking about their investments. If a user input contradicts previously stated preferences or profile attributes (e.g., low risk appetite or crypto aversion), ask a clarifying question before proceeding. Do not update allocations or goals unless the user confirms the change explicitly.
2. **Analyze Allocations**: Evaluate the savings allocation data to understand the user's current financial posture.
3. **Use Retrieved Context**: If any contextual information is provided in `retrieved_context`, leverage it to improve your response quality and relevance.
4. **Always Update Information**: If the user shares any new demographic, financial, or preference-related data, update the user profile accordingly. If they request changes in their allocations, ensure the changes are applied **proportionally** and that the total allocation always sums to 100%.
5. **IMPORTANT: When displaying or updating allocations, you MUST format the data as a Markdown table and always display allocations as a table only** using the following columns:
- Asset Class
- Type
- Label
- Old Amount (₹)
- Change (₹)
- New Amount (₹)
- Justification


6. **Maintain Conversational Memory**: Ensure updates are passed to memory using the specified `updates` structure.
7. **Tool Use Policy**:
- ✅ Use `rag_tool` for retrieving **external financial knowledge or regulation** context when necessary.


---

### 🎯 Response Style Guide:

- 📝 Keep it under 300 words.
- 😊 Friendly tone: be warm and helpful.
- 📚 Structured: use bullet points, short paragraphs, and headers.
- 👀 Visually clear: break sections clearly.
- 🌟 Use emojis to guide attention and convey tone.
- 🎯 Be direct and focused on the user's request.

---

### 🔁 If There Are Allocation Changes:

You **must** display a Markdown table as per the format above. Then, return memory update instructions using this JSON structure:
```json
{{
"updates": {{
"user_data": {{ ... }}, // Include only changed fields
"allocations": {{...}} // Include only changed rows
}}
}}
```
'''
|
244 |
+
|
245 |
+
# Create the prompt template
|
246 |
+
simple_prompt = ChatPromptTemplate.from_messages([
|
247 |
+
SystemMessagePromptTemplate.from_template(template=template),
|
248 |
+
MessagesPlaceholder(variable_name="chat_history", optional=True),
|
249 |
+
HumanMessagePromptTemplate.from_template("User Query: {query}"),
|
250 |
+
HumanMessagePromptTemplate.from_template("Current User Profile:\n{user_data}"),
|
251 |
+
HumanMessagePromptTemplate.from_template("Current Allocations:\n{allocations}"),
|
252 |
+
HumanMessagePromptTemplate.from_template("🔎 Retrieved Context (if any):\n{retrieved_context}"),
|
253 |
+
])
|
254 |
+
|
255 |
+
# Create the chain with direct tool binding.
# `llm` is deliberately rebound here: the chains built earlier in this module
# already hold the previous instance, and this lower-temperature one backs the
# tool-enabled chain.
llm = ChatOpenAI(
    temperature=0.1,
    model="gpt-4.1-nano",
    # Read the same OPEN_AI_KEY variable as the first ChatOpenAI in this
    # module, falling back to the library's standard OPENAI_API_KEY.
    api_key=os.environ.get('OPEN_AI_KEY') or os.environ.get('OPENAI_API_KEY'),
)
# Expose the function schemas listed in `tools` to the model.
llm_with_tools = llm.bind_tools(tools)
simple_chain = simple_prompt | llm_with_tools
|
entities.py
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from enum import IntEnum, StrEnum, Enum
|
2 |
+
|
3 |
+
class KBCategory(str, Enum):
    """Knowledge-base category identifiers.

    Members also subclass ``str``, so each compares equal to its value.
    """

    ProductCategory = "product_category"
    InvestmentRegulations = "investment_regulations"
    TaxationDetails = "taxation_details"
    MarketSegments = "market_segments"
    CulturalAspects = "cultural_aspects"
    General = "general"
|
10 |
+
|
11 |
+
class THRESHOLD(Enum):
    """Single-member enum holding a threshold constant.

    Read the number through ``THRESHOLD.threshold.value``.
    """

    threshold = 0.2
|
requirements.txt
ADDED
@@ -0,0 +1,213 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
absl-py==2.2.2
|
2 |
+
agents==1.4.0
|
3 |
+
aiohappyeyeballs==2.6.1
|
4 |
+
aiohttp==3.10.11
|
5 |
+
aiosignal==1.3.2
|
6 |
+
altair==5.5.0
|
7 |
+
annotated-types==0.7.0
|
8 |
+
anyio==4.9.0
|
9 |
+
asttokens==3.0.0
|
10 |
+
astunparse==1.6.3
|
11 |
+
attrs==25.3.0
|
12 |
+
beautifulsoup4==4.12.2
|
13 |
+
blinker==1.9.0
|
14 |
+
cachetools==5.5.2
|
15 |
+
certifi==2025.1.31
|
16 |
+
charset-normalizer==3.4.1
|
17 |
+
click==8.1.3
|
18 |
+
cloudpickle==3.1.1
|
19 |
+
colorama==0.4.6
|
20 |
+
contourpy==1.3.1
|
21 |
+
cycler==0.12.1
|
22 |
+
dataclasses-json==0.6.7
|
23 |
+
decorator==5.2.1
|
24 |
+
distro==1.9.0
|
25 |
+
executing==2.2.0
|
26 |
+
fastapi==0.115.12
|
27 |
+
filelock==3.18.0
|
28 |
+
filetype==1.2.0
|
29 |
+
flask==2.2.3
|
30 |
+
flask-httpauth==4.8.0
|
31 |
+
flatbuffers==25.2.10
|
32 |
+
fonttools==4.57.0
|
33 |
+
frozenlist==1.5.0
|
34 |
+
fsspec==2025.3.2
|
35 |
+
gast==0.6.0
|
36 |
+
gitdb==4.0.12
|
37 |
+
gitpython==3.1.44
|
38 |
+
google-ai-generativelanguage==0.6.17
|
39 |
+
google-api-core==2.24.2
|
40 |
+
google-auth==2.38.0
|
41 |
+
google-pasta==0.2.0
|
42 |
+
googleapis-common-protos==1.69.2
|
43 |
+
grandalf==0.8
|
44 |
+
graphviz==0.20.3
|
45 |
+
greenlet==3.1.1
|
46 |
+
griffe==1.7.2
|
47 |
+
grpcio==1.71.0
|
48 |
+
grpcio-status==1.71.0
|
49 |
+
gunicorn==21.2.0
|
50 |
+
gym==0.26.2
|
51 |
+
gym-notices==0.0.8
|
52 |
+
h11==0.14.0
|
53 |
+
h5py==3.13.0
|
54 |
+
httpcore==1.0.7
|
55 |
+
httpx==0.28.1
|
56 |
+
httpx-sse==0.4.0
|
57 |
+
huggingface-hub==0.30.1
|
58 |
+
idna==3.10
|
59 |
+
iniconfig==2.1.0
|
60 |
+
ipython==9.0.2
|
61 |
+
ipython-pygments-lexers==1.1.1
|
62 |
+
itsdangerous==2.1.2
|
63 |
+
jedi==0.19.2
|
64 |
+
jinja2==3.1.2
|
65 |
+
jiter==0.9.0
|
66 |
+
joblib==1.4.2
|
67 |
+
jsonpatch==1.33
|
68 |
+
jsonpointer==3.0.0
|
69 |
+
jsonschema==4.23.0
|
70 |
+
jsonschema-specifications==2024.10.1
|
71 |
+
keras==3.9.2
|
72 |
+
kiwisolver==1.4.8
|
73 |
+
langchain==0.3.21
|
74 |
+
langchain-community==0.3.20
|
75 |
+
langchain-core==0.3.49
|
76 |
+
langchain-google-genai==2.1.2
|
77 |
+
langchain-huggingface==0.1.2
|
78 |
+
langchain-openai==0.3.11
|
79 |
+
langchain-pinecone==0.2.3
|
80 |
+
langchain-tests==0.3.17
|
81 |
+
langchain-text-splitters==0.3.7
|
82 |
+
langgraph==0.3.21
|
83 |
+
langgraph-checkpoint==2.0.23
|
84 |
+
langgraph-prebuilt==0.1.7
|
85 |
+
langgraph-sdk==0.1.60
|
86 |
+
langsmith==0.3.19
|
87 |
+
libclang==18.1.1
|
88 |
+
markdown==3.7
|
89 |
+
markdown-it-py==3.0.0
|
90 |
+
markupsafe==2.1.2
|
91 |
+
marshmallow==3.26.1
|
92 |
+
matplotlib==3.10.1
|
93 |
+
matplotlib-inline==0.1.7
|
94 |
+
mcp==1.6.0
|
95 |
+
mdurl==0.1.2
|
96 |
+
ml-dtypes==0.5.1
|
97 |
+
mpmath==1.3.0
|
98 |
+
multidict==6.3.1
|
99 |
+
mypy-extensions==1.0.0
|
100 |
+
namex==0.0.8
|
101 |
+
narwhals==1.34.1
|
102 |
+
networkx==3.4.2
|
103 |
+
numpy==1.26.4
|
104 |
+
nvidia-cublas-cu12==12.4.5.8
|
105 |
+
nvidia-cuda-cupti-cu12==12.4.127
|
106 |
+
nvidia-cuda-nvrtc-cu12==12.4.127
|
107 |
+
nvidia-cuda-runtime-cu12==12.4.127
|
108 |
+
nvidia-cudnn-cu12==9.1.0.70
|
109 |
+
nvidia-cufft-cu12==11.2.1.3
|
110 |
+
nvidia-curand-cu12==10.3.5.147
|
111 |
+
nvidia-cusolver-cu12==11.6.1.9
|
112 |
+
nvidia-cusparse-cu12==12.3.1.170
|
113 |
+
nvidia-cusparselt-cu12==0.6.2
|
114 |
+
nvidia-nccl-cu12==2.21.5
|
115 |
+
nvidia-nvjitlink-cu12==12.4.127
|
116 |
+
nvidia-nvtx-cu12==12.4.127
|
117 |
+
openai==1.69.0
|
118 |
+
openai-agents==0.0.10
|
119 |
+
opt-einsum==3.4.0
|
120 |
+
optree==0.15.0
|
121 |
+
orjson==3.10.16
|
122 |
+
ormsgpack==1.9.1
|
123 |
+
packaging==24.2
|
124 |
+
pandas==2.2.3
|
125 |
+
parso==0.8.4
|
126 |
+
pexpect==4.9.0
|
127 |
+
pillow==11.1.0
|
128 |
+
pinecone==5.4.2
|
129 |
+
pinecone-plugin-inference==3.1.0
|
130 |
+
pinecone-plugin-interface==0.0.7
|
131 |
+
pluggy==1.5.0
|
132 |
+
prompt-toolkit==3.0.50
|
133 |
+
propcache==0.3.1
|
134 |
+
proto-plus==1.26.1
|
135 |
+
protobuf==5.29.4
|
136 |
+
ptyprocess==0.7.0
|
137 |
+
pure-eval==0.2.3
|
138 |
+
pyarrow==19.0.1
|
139 |
+
pyasn1==0.6.1
|
140 |
+
pyasn1-modules==0.4.2
|
141 |
+
pydantic==2.11.1
|
142 |
+
pydantic-core==2.33.0
|
143 |
+
pydantic-settings==2.8.1
|
144 |
+
pydeck==0.9.1
|
145 |
+
pygments==2.19.1
|
146 |
+
pymupdf==1.25.5
|
147 |
+
pyparsing==3.2.3
|
148 |
+
pypd==1.1.0
|
149 |
+
pypdf==5.4.0
|
150 |
+
pytest==8.3.5
|
151 |
+
pytest-asyncio==0.26.0
|
152 |
+
pytest-mock==3.14.0
|
153 |
+
pytest-socket==0.7.0
|
154 |
+
python-dateutil==2.9.0.post0
|
155 |
+
python-dotenv==1.0.0
|
156 |
+
pytz==2025.2
|
157 |
+
pyyaml==6.0.2
|
158 |
+
referencing==0.36.2
|
159 |
+
regex==2024.11.6
|
160 |
+
requests==2.31.0
|
161 |
+
requests-toolbelt==1.0.0
|
162 |
+
rich==14.0.0
|
163 |
+
rpds-py==0.24.0
|
164 |
+
rsa==4.9
|
165 |
+
ruamel-yaml==0.18.10
|
166 |
+
ruamel-yaml-clib==0.2.12
|
167 |
+
safetensors==0.5.3
|
168 |
+
scikit-learn==1.6.1
|
169 |
+
scipy==1.15.2
|
170 |
+
sentence-transformers==4.0.1
|
171 |
+
setuptools==78.1.0
|
172 |
+
six==1.17.0
|
173 |
+
smmap==5.0.2
|
174 |
+
sniffio==1.3.1
|
175 |
+
soupsieve==2.6
|
176 |
+
sqlalchemy==2.0.40
|
177 |
+
sse-starlette==2.2.1
|
178 |
+
stack-data==0.6.3
|
179 |
+
starlette==0.46.1
|
180 |
+
streamlit==1.44.1
|
181 |
+
sympy==1.13.1
|
182 |
+
syrupy==4.9.1
|
183 |
+
tabulate==0.9.0
|
184 |
+
tenacity==9.0.0
|
185 |
+
tensorboard==2.19.0
|
186 |
+
tensorboard-data-server==0.7.2
|
187 |
+
|
188 |
+
termcolor==3.0.1
|
189 |
+
threadpoolctl==3.6.0
|
190 |
+
tiktoken==0.9.0
|
191 |
+
tokenizers==0.21.1
|
192 |
+
toml==0.10.2
|
193 |
+
torch==2.6.0
|
194 |
+
tornado==6.4.2
|
195 |
+
tqdm==4.67.1
|
196 |
+
traitlets==5.14.3
|
197 |
+
transformers==4.50.3
|
198 |
+
triton==3.2.0
|
199 |
+
types-requests==2.32.0.20250328
|
200 |
+
typing-extensions==4.13.0
|
201 |
+
typing-inspect==0.9.0
|
202 |
+
typing-inspection==0.4.0
|
203 |
+
tzdata==2025.2
|
204 |
+
urllib3==2.3.0
|
205 |
+
uvicorn==0.34.0
|
206 |
+
watchdog==6.0.0
|
207 |
+
wcwidth==0.2.13
|
208 |
+
werkzeug==2.2.3
|
209 |
+
wheel==0.45.1
|
210 |
+
wrapt==1.17.2
|
211 |
+
xxhash==3.5.0
|
212 |
+
yarl==1.18.3
|
213 |
+
zstandard==0.23.0
|
sample_user_data.json
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"sematic": {
|
3 |
+
"demographic": {
|
4 |
+
"age": 30,
|
5 |
+
"employment_type": "salaried",
|
6 |
+
"dependents": 0,
|
7 |
+
"health_status": "good",
|
8 |
+
"risk_appetite": 0,
|
9 |
+
"financial_maturity": 0,
|
10 |
+
"location": "tier_1"
|
11 |
+
|
12 |
+
},
|
13 |
+
"financial": {
|
14 |
+
"salary": 100000,
|
15 |
+
"business_value": 0,
|
16 |
+
"current_savings_and_investments": 1000000,
|
17 |
+
"debts": 0,
|
18 |
+
"market_outlook": "neutral",
|
19 |
+
"include_insights": true,
|
20 |
+
"is_housing_loan": false,
|
21 |
+
"monthly_expenses": 50000,
|
22 |
+
"property_value": 0,
|
23 |
+
"real_estate_type": "tier_1_residential",
|
24 |
+
"real_estate_value": 0,
|
25 |
+
"region": "ind",
|
26 |
+
"savings_percentage": 20
|
27 |
+
}
|
28 |
+
},
|
29 |
+
"episodic": {
|
30 |
+
|
31 |
+
"prefrences": [
|
32 |
+
"doesn't like crypto",
|
33 |
+
"doesn't want to be exposed to energy sector too much"
|
34 |
+
]
|
35 |
+
}
|
36 |
+
}
|
37 |
+
|
simple_rag.py
ADDED
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import os
|
3 |
+
|
4 |
+
from langgraph.graph import START, END, StateGraph
|
5 |
+
from langchain_openai import OpenAIEmbeddings
|
6 |
+
from chains import simple_chain, llm_with_tools
|
7 |
+
from langchain_core.messages import BaseMessage, HumanMessage, ToolMessage, AIMessage
|
8 |
+
from typing import TypedDict, Optional, Dict, List, Union, Annotated
|
9 |
+
from langchain_core.messages import AnyMessage #human or AI message
|
10 |
+
from langgraph.graph.message import add_messages # reducer in langgraph
|
11 |
+
from langgraph.prebuilt import ToolNode, tools_condition
|
12 |
+
from langchain.agents import initialize_agent, Tool
|
13 |
+
from langchain.agents.agent_types import AgentType
|
14 |
+
from langgraph.checkpoint.memory import MemorySaver
|
15 |
+
import json
|
16 |
+
import langchain
|
17 |
+
from tools import json_to_table, goal_feasibility, rag_tool, save_data
|
18 |
+
import re
|
19 |
+
|
20 |
+
from dotenv import load_dotenv
|
21 |
+
load_dotenv()
|
22 |
+
|
23 |
+
# In-memory checkpointer; passed to graph.compile() at the bottom of this module.
memory = MemorySaver()

# Thread identifier used for the sample conversation.
config = dict(thread_id="sample")

# Tool callables imported from tools.py.
tools = [json_to_table, rag_tool]

# One prebuilt ToolNode per tool.
# NOTE: the name json_to_table_node is rebound by the
# `def json_to_table_node(state)` that appears later in this module.
json_to_table_node = ToolNode([json_to_table])
rag_tool_node = ToolNode([rag_tool])
|
30 |
+
class Graph(TypedDict):
    # State schema handed to StateGraph below.

    # Conversation messages; updates are merged through the add_messages reducer.
    query: Annotated[list[AnyMessage], add_messages]
    # User profile payload forwarded to the chain by chat().
    user_data: Dict
    # Allocation payload forwarded to the chain and rendered by json_to_table.
    allocations: Dict
    # Declared as Dict; chat() stores `result.content` here.
    # NOTE(review): that value is presumably a string - confirm the intended type.
    output: Dict
    # Retrieved text made available to the next chat() call; chat() resets it to "".
    retrieved_context: str
|
38 |
+
|
39 |
+
def chat(state):
    """Run one turn of ``simple_chain`` and return the updated graph state.

    The chain receives the message list twice (as ``query`` and as
    ``chat_history``), plus the user data, the allocations and any retrieved
    context. The reply's ``content`` is stored under ``output`` and
    ``retrieved_context`` is cleared so it is only used once.
    """
    history = state["query"]
    profile = state["user_data"]
    allocations = state["allocations"]

    reply = simple_chain.invoke(
        {
            "query": history,
            "user_data": profile,
            "allocations": allocations,
            "chat_history": history,  # the message list doubles as the history
            "retrieved_context": state.get("retrieved_context", ""),
        }
    )

    # NOTE(review): the reply message itself is not appended to "query", so the
    # routing function only ever sees the messages that were already in state -
    # confirm this is intended.
    return {
        "query": history,
        "user_data": profile,
        "allocations": allocations,
        "retrieved_context": "",  # clear after use
        "output": reply.content,
    }
|
60 |
+
|
61 |
+
def json_to_table_node(state):
    """Render ``state["allocations"]`` as a markdown table and add it to the chat.

    Returns a state update (a dict) rather than a bare message, because graph
    nodes communicate by returning updates keyed by state field. The new
    ``AIMessage`` is appended to ``query`` through that field's reducer.
    """
    # Local import: only this node needs it.
    from types import SimpleNamespace

    # json_to_table reads its payload from a `.json_data` attribute, so the raw
    # allocations dict is wrapped instead of being passed directly.
    tool_output = json_to_table(SimpleNamespace(json_data=state["allocations"]))
    return {"query": [AIMessage(content=tool_output)]}
|
64 |
+
|
65 |
+
def tools_condition(state):
    """Pick the next node after ``chat`` based on the latest message.

    Returns the name of the node that handles the first requested tool call,
    or the ``END`` sentinel when the last message is not an ``AIMessage`` with
    tool calls (including when there are no messages at all).
    """
    messages = state["query"]
    if not messages:
        # Nothing to inspect; finish instead of failing on messages[-1].
        return END

    last_message = messages[-1]  # Last user or AI message
    if isinstance(last_message, AIMessage):
        tool_calls = getattr(last_message, "tool_calls", None)
        if tool_calls:
            # Only the first tool call decides the route.
            tool_name = tool_calls[0].get('name', '')
            if tool_name == "json_to_table":
                return "show_allocation_table"
            if tool_name == "rag_tool":
                return "query_rag"
            # NOTE(review): the "tools" node registration is commented out in
            # the graph setup of this module, so this fallback currently has no
            # target - confirm before relying on it.
            return "tools"

    # Use the END sentinel, not the string "END", which is not a node name.
    return END
|
82 |
+
|
83 |
+
|
84 |
+
# ---- GRAPH SETUP ----
|
85 |
+
graph = StateGraph(Graph)

# Register the nodes in a fixed order. "tool_output_to_message" has no edge
# leading to it anywhere in this module.
for _node_name, _node_action in (
    ("chat", chat),
    ("show_allocation_table", json_to_table_node),
    ("query_rag", rag_tool_node),
    ("tool_output_to_message", lambda state: AIMessage(content=state["tool_output"])),
):
    graph.add_node(_node_name, _node_action)

# Entry point, then routing decided by tools_condition.
graph.add_edge(START, "chat")
graph.add_conditional_edges("chat", tools_condition)

# Every tool node hands control back to the chat node.
for _tool_node_name in ("show_allocation_table", "query_rag"):
    graph.add_edge(_tool_node_name, "chat")

# Static edge that ends the run after chat.
graph.add_edge("chat", END)

# Compile with the in-memory checkpointer.
app = graph.compile(checkpointer=memory)
|
113 |
+
|
114 |
+
'''
|
115 |
+
with open('/home/pavan/Desktop/FOLDERS/RUBIC/RAG_without_profiler/RAG_rubik/sample_data/sample_alloc.json', 'r') as f:
|
116 |
+
data = json.load(f)
|
117 |
+
with open('/home/pavan/Desktop/FOLDERS/RUBIC/RAG_without_profiler/RAG_rubik/sample_data/sample_alloc.json', 'r') as f:
|
118 |
+
allocs = json.load(f)
|
119 |
+
inputs = {
|
120 |
+
"query":"display my investments.",
|
121 |
+
"user_data":data,
|
122 |
+
"allocations":allocs,
|
123 |
+
"data":"",
|
124 |
+
"chat_history": [],
|
125 |
+
|
126 |
+
}
|
127 |
+
|
128 |
+
langchain.debug = True
|
129 |
+
print(app.invoke(inputs, config={"configurable": {"thread_id": "sample"}}).get('output'))
|
130 |
+
#print(json_to_table.args_schema.model_json_schema())
|
131 |
+
'''
|
tools.py
ADDED
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import os
|
3 |
+
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
4 |
+
from langchain.tools import tool
|
5 |
+
import pandas as pd
|
6 |
+
import json
|
7 |
+
import re
|
8 |
+
from copy import deepcopy
|
9 |
+
from langchain_pinecone import PineconeVectorStore
|
10 |
+
from dotenv import load_dotenv
|
11 |
+
load_dotenv()
|
12 |
+
from langchain_openai import OpenAIEmbeddings
|
13 |
+
from pydantic import BaseModel
|
14 |
+
from typing import Any, Optional
|
15 |
+
|
16 |
+
# Pinecone credential read from the environment. The correctly spelled
# PINECONE_API_KEY is preferred; the misspelled PINCEONE_API_KEY this module
# originally used is still honoured so existing deployments keep working.
api_key = os.getenv('PINECONE_API_KEY') or os.getenv('PINCEONE_API_KEY')
|
17 |
+
|
18 |
+
class JsonToTableInput(BaseModel):
    # Argument wrapper for json_to_table, which reads `json_data` from it.
    # Typed Any: json_to_table handles str and dict payloads explicitly.
    json_data: Any
|
20 |
+
|
21 |
+
class RagToolInput(BaseModel):
    # Argument wrapper for rag_tool; `query` is the text handed to the retriever.
    query: str
|
23 |
+
|
24 |
+
# Define the tools with proper validation
|
25 |
+
def json_to_table(input_data: JsonToTableInput):
    """Convert JSON data to a markdown table. Use when user asks to visualise or tabulate structured data."""
    # The docstring above is kept verbatim because tool wrappers may expose it
    # as the tool description.
    #
    # Accept either the wrapper object (anything with a `json_data` attribute)
    # or the raw payload itself: json_to_table_node in simple_rag.py calls this
    # function with the allocations dict directly.
    json_data = getattr(input_data, "json_data", input_data)

    if isinstance(json_data, str):
        try:
            json_data = json.loads(json_data)
        except ValueError:
            # json.JSONDecodeError is a ValueError. Best effort: keep the raw
            # string and let the conversion below deal with it.
            # NOTE(review): pd.json_normalize may reject a plain string.
            pass

    # Handle a common case in the prompt where 'allocations' might be a nested key
    if isinstance(json_data, dict) and 'allocations' in json_data:
        json_data = json_data['allocations']

    # Ensure we have a valid list or dict to convert to DataFrame
    if not json_data:
        json_data = [{"Note": "No allocation data available"}]

    df = pd.json_normalize(json_data)
    markdown_table = df.to_markdown(index=False)
    print(f"[DEBUG] json_to_table output:\n{markdown_table}")

    return markdown_table
|
49 |
+
|
50 |
+
def rag_tool(input_data: RagToolInput):
    """Lets the agent use RAG system as a tool"""
    # The docstring above is kept verbatim because tool wrappers may expose it
    # as the tool description.
    #
    # Retrieves the 10 best-matching documents for `input_data.query` from the
    # 'rag-rubic' index (namespace 'vectors_lightmodel') and returns their
    # page contents joined by newlines.
    query = input_data.query

    embedding_model = OpenAIEmbeddings(
        model="text-embedding-3-small",
        dimensions=384,
    )

    # Prefer the correctly spelled variable; fall back to the misspelled name
    # this module originally read so existing environments keep working.
    pinecone_key = (
        os.environ.get('PINECONE_API_KEY')
        or os.environ.get('PINCEONE_API_KEY')
    )

    kb = PineconeVectorStore(
        # The vector store needs the embedding model to embed the query; the
        # original built the model but never passed it in.
        embedding=embedding_model,
        pinecone_api_key=pinecone_key,
        index_name='rag-rubic',
        namespace='vectors_lightmodel',
    )
    retriever = kb.as_retriever(search_kwargs={"k": 10})
    documents = retriever.invoke(query)
    return "\n".join(doc.page_content for doc in documents)
|
66 |
+
|
67 |
+
@tool
|
68 |
+
def goal_feasibility(goal_amount: float, timeline: float, current_savings: float, income: float) -> dict:
    """Evaluate if a financial goal is feasible based on user income, timeline, and savings. Use when user asks about goal feasibility."""
    # The docstring above is kept verbatim: the @tool decorator applied to this
    # function may expose it as the tool description.
    #
    # Returns a dict with keys "feasible" (bool), "status" (str),
    # "monthly_required" (number) and "reason" (str).
    #
    # Classification uses the share of income the monthly saving would take:
    #   <= 30%  -> "Feasible"
    #   <= 70%  -> "Difficult"
    #   above   -> "Infeasible"

    # Input checks
    if timeline <= 0:
        return {
            "feasible": False,
            "status": "Invalid",
            "monthly_required": 0,
            "reason": "Timeline must be greater than 0 months."
        }

    # Calculate the remaining amount
    remaining_amount = goal_amount - current_savings
    if remaining_amount <= 0:
        return {
            "feasible": True,
            "status": "Already Achieved",
            "monthly_required": 0,
            "reason": "You have already met or exceeded your savings goal."
        }

    monthly_required = remaining_amount / timeline

    # A zero income would divide by zero below, and a negative income would
    # produce a negative ratio that is wrongly classified as "Feasible".
    if income <= 0:
        return {
            "feasible": False,
            "status": "Invalid",
            "monthly_required": round(monthly_required, 2),
            "reason": "Income must be greater than 0."
        }

    income_ratio = monthly_required / income

    # Feasibility classification
    if income_ratio <= 0.3:
        status = "Feasible"
        feasible = True
        reason = "The required savings per month is manageable for an average income."
    elif income_ratio <= 0.7:
        status = "Difficult"
        feasible = False
        reason = "The required monthly saving is high but may be possible with strict budgeting."
    else:
        status = "Infeasible"
        feasible = False
        reason = "The required monthly saving is unrealistic for an average income."

    return {
        "feasible": feasible,
        "status": status,
        "monthly_required": round(monthly_required, 2),
        "reason": reason
    }
|
112 |
+
|
113 |
+
|
114 |
+
@tool
|
115 |
+
def save_data(new_user_data:dict, new_alloc_data:dict):
    "Saves the updated user_data and allocations data in a json file."
    # The docstring above is kept verbatim: the @tool decorator applied to this
    # function may expose it as the tool description.
    #
    # Files are written to <DATA_PATH>/updated_json, where DATA_PATH defaults
    # to the current directory; the directory is created if it is missing.
    target_dir = os.path.join(os.getenv("DATA_PATH", "."), "updated_json")
    os.makedirs(target_dir, exist_ok=True)

    # (file name, payload) pairs, written in this order.
    outputs = (
        ("updated_user_data.json", new_user_data),
        ("updated_allocations.json", new_alloc_data),
    )
    for file_name, payload in outputs:
        with open(os.path.join(target_dir, file_name), "w") as handle:
            json.dump(payload, handle, indent=2)
|
125 |
+
|