WebashalarForML commited on
Commit
e6ce96f
·
verified ·
1 Parent(s): 5a386b0

Create demo.py

Browse files
Files changed (1) hide show
  1. demo.py +357 -0
demo.py ADDED
@@ -0,0 +1,357 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from flask import Flask, render_template, request, redirect, url_for, send_from_directory, flash
2
+ from flask_socketio import SocketIO
3
+ import threading
4
+ import os
5
+ from dotenv import load_dotenv
6
+ import sqlite3
7
+ from werkzeug.utils import secure_filename
8
+
9
+ # LangChain and agent imports
10
+ from langchain_community.chat_models.huggingface import ChatHuggingFace # if needed later
11
+ from langchain.agents import Tool
12
+ from langchain.agents.format_scratchpad import format_log_to_str
13
+ from langchain.agents.output_parsers import ReActJsonSingleInputOutputParser
14
+ from langchain_core.callbacks import CallbackManager, BaseCallbackHandler
15
+ from langchain_community.agent_toolkits.load_tools import load_tools
16
+ from langchain_core.tools import tool
17
+ from langchain_community.agent_toolkits import PowerBIToolkit
18
+ from langchain.chains import LLMMathChain
19
+ from langchain import hub
20
+ from langchain_community.tools import DuckDuckGoSearchRun
21
+
22
+ # Agent requirements and type hints
23
+ from typing import Annotated, Literal, TypedDict, Any
24
+ from langchain_core.messages import AIMessage, ToolMessage
25
+ from pydantic import BaseModel, Field
26
+ from typing_extensions import TypedDict
27
+ from langgraph.graph import END, StateGraph, START
28
+ from langgraph.graph.message import AnyMessage, add_messages
29
+ from langchain_core.runnables import RunnableLambda, RunnableWithFallbacks
30
+ from langgraph.prebuilt import ToolNode
31
+
32
+ import traceback
33
+
34
+ # Load environment variables
35
+ load_dotenv()
36
+
37
+
38
+ # Global configuration variables
39
+ UPLOAD_FOLDER = os.path.join(os.getcwd(), "uploads")
40
+ BASE_DIR = os.path.abspath(os.path.dirname(__file__))
41
+ DATABASE_URI = f"sqlite:///{os.path.join(BASE_DIR, 'data', 'mydb.db')}"
42
+ print("DATABASE URI:", DATABASE_URI)
43
+
44
+ # API Keys from .env file
45
+ GROQ_API_KEY = os.getenv("GROQ_API_KEY")
46
+ MISTRAL_API_KEY = os.getenv("MISTRAL_API_KEY")
47
+ os.environ["GROQ_API_KEY"] = GROQ_API_KEY
48
+ os.environ["MISTRAL_API_KEY"] = MISTRAL_API_KEY
49
+
50
+ # Global variables for dynamic agent and DB file path; initially None.
51
+ agent_app = None
52
+ abs_file_path = None
53
+ db_path = None
54
+
55
+ print(traceback.format_exc())
56
+
57
+ # =============================================================================
58
+ # create_agent_app: Given a database path, initialize the agent workflow.
59
+ # =============================================================================
60
+
61
+ def create_agent_app(db_path: str):
62
+ # Use ChatGroq as our LLM here; you can swap to ChatMistralAI if preferred.
63
+ from langchain_groq import ChatGroq
64
+ llm = ChatGroq(model="llama3-70b-8192")
65
+
66
+ # -------------------------------------------------------------------------
67
+ # Define a tool for executing SQL queries.
68
+ # -------------------------------------------------------------------------
69
+ @tool
70
+ def db_query_tool(query: str) -> str:
71
+ """
72
+ Executes a SQL query on the connected SQLite database.
73
+
74
+ Parameters:
75
+ query (str): A SQL query string to be executed.
76
+
77
+ Returns:
78
+ str: The result from the database if successful, or an error message if not.
79
+ """
80
+ result = db_instance.run_no_throw(query)
81
+ return result if result else "Error: Query failed. Please rewrite your query and try again."
82
+
83
+ # -------------------------------------------------------------------------
84
+ # Pydantic model for final answer
85
+ # -------------------------------------------------------------------------
86
+ class SubmitFinalAnswer(BaseModel):
87
+ final_answer: str = Field(..., description="The final answer to the user")
88
+
89
+ # -------------------------------------------------------------------------
90
+ # Define state type for our workflow.
91
+ # -------------------------------------------------------------------------
92
+ class State(TypedDict):
93
+ messages: Annotated[list[AnyMessage], add_messages]
94
+
95
+ # -------------------------------------------------------------------------
96
+ # Set up prompt templates (using langchain_core.prompts) for query checking
97
+ # and query generation.
98
+ # -------------------------------------------------------------------------
99
+ from langchain_core.prompts import ChatPromptTemplate
100
+
101
+ query_check_system = (
102
+ "You are a SQL expert with a strong attention to detail.\n"
103
+ "Double check the SQLite query for common mistakes, including:\n"
104
+ "- Using NOT IN with NULL values\n"
105
+ "- Using UNION when UNION ALL should have been used\n"
106
+ "- Using BETWEEN for exclusive ranges\n"
107
+ "- Data type mismatch in predicates\n"
108
+ "- Properly quoting identifiers\n"
109
+ "- Using the correct number of arguments for functions\n"
110
+ "- Casting to the correct data type\n"
111
+ "- Using the proper columns for joins\n\n"
112
+ "If there are any of the above mistakes, rewrite the query. If there are no mistakes, just reproduce the original query.\n"
113
+ "You will call the appropriate tool to execute the query after running this check."
114
+ )
115
+ query_check_prompt = ChatPromptTemplate.from_messages([
116
+ ("system", query_check_system),
117
+ ("placeholder", "{messages}")
118
+ ])
119
+ query_check = query_check_prompt | llm.bind_tools([db_query_tool])
120
+
121
+ query_gen_system = (
122
+ "You are a SQL expert with a strong attention to detail.\n\n"
123
+ "Given an input question, output a syntactically correct SQLite query to run, then look at the results of the query and return the answer.\n\n"
124
+ "DO NOT call any tool besides SubmitFinalAnswer to submit the final answer.\n\n"
125
+ "When generating the query:\n"
126
+ "Output the SQL query that answers the input question without a tool call.\n"
127
+ "Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most 5 results.\n"
128
+ "You can order the results by a relevant column to return the most interesting examples in the database.\n"
129
+ "Never query for all the columns from a specific table, only ask for the relevant columns given the question.\n\n"
130
+ "If you get an error while executing a query, rewrite the query and try again.\n"
131
+ "If you get an empty result set, you should try to rewrite the query to get a non-empty result set.\n"
132
+ "NEVER make stuff up if you don't have enough information to answer the query... just say you don't have enough information.\n\n"
133
+ "If you have enough information to answer the input question, simply invoke the appropriate tool to submit the final answer to the user.\n"
134
+ "DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database. Do not return any SQL query except answer."
135
+ )
136
+ query_gen_prompt = ChatPromptTemplate.from_messages([
137
+ ("system", query_gen_system),
138
+ ("placeholder", "{messages}")
139
+ ])
140
+ query_gen = query_gen_prompt | llm.bind_tools([SubmitFinalAnswer])
141
+
142
+ # -------------------------------------------------------------------------
143
+ # Update database URI and file path, create SQLDatabase connection.
144
+ # -------------------------------------------------------------------------
145
+
146
+ abs_db_path_local = os.path.abspath(db_path)
147
+ global DATABASE_URI
148
+ DATABASE_URI = abs_db_path_local
149
+ db_uri = f"sqlite:///{abs_db_path_local}"
150
+ print("db_uri", db_uri)
151
+ # Uncomment if flash is needed; ensure you have flask.flash imported if so.
152
+ # flash(f"db_uri:{db_uri}", "warning")
153
+
154
+ from langchain_community.utilities import SQLDatabase
155
+ db_instance = SQLDatabase.from_uri(db_uri)
156
+ print("db_instance----->", db_instance)
157
+ # flash(f"db_instance:{db_instance}", "warning")
158
+
159
+ # -------------------------------------------------------------------------
160
+ # Create SQL toolkit.
161
+ # -------------------------------------------------------------------------
162
+
163
+ from langchain_community.agent_toolkits import SQLDatabaseToolkit
164
+ toolkit_instance = SQLDatabaseToolkit(db=db_instance, llm=llm)
165
+ tools_instance = toolkit_instance.get_tools()
166
+
167
+ # -------------------------------------------------------------------------
168
+ # Define workflow nodes and fallback functions.
169
+ # -------------------------------------------------------------------------
170
+
171
+ def first_tool_call(state: State) -> dict[str, list[AIMessage]]:
172
+ return {"messages": [AIMessage(content="", tool_calls=[{"name": "sql_db_list_tables", "args": {}, "id": "tool_abcd123"}])]}
173
+
174
+ def handle_tool_error(state: State) -> dict:
175
+ error = state.get("error")
176
+ tool_calls = state["messages"][-1].tool_calls
177
+ return {"messages": [
178
+ ToolMessage(content=f"Error: {repr(error)}. Please fix your mistakes.", tool_call_id=tc["id"])
179
+ for tc in tool_calls
180
+ ]}
181
+
182
+ def create_tool_node_with_fallback(tools_list: list) -> RunnableWithFallbacks[Any, dict]:
183
+ return ToolNode(tools_list).with_fallbacks([RunnableLambda(handle_tool_error)], exception_key="error")
184
+
185
+ def query_gen_node(state: State):
186
+ message = query_gen.invoke(state)
187
+ tool_messages = []
188
+ if message.tool_calls:
189
+ for tc in message.tool_calls:
190
+ if tc["name"] != "SubmitFinalAnswer":
191
+ tool_messages.append(ToolMessage(
192
+ content=f"Error: The wrong tool was called: {tc['name']}. Please fix your mistakes.",
193
+ tool_call_id=tc["id"]
194
+ ))
195
+ return {"messages": [message] + tool_messages}
196
+
197
+ def should_continue(state: State) -> Literal[END, "correct_query", "query_gen"]:
198
+ messages = state["messages"]
199
+ last_message = messages[-1]
200
+ if getattr(last_message, "tool_calls", None):
201
+ return END
202
+ if last_message.content.startswith("Error:"):
203
+ return "query_gen"
204
+ return "correct_query"
205
+
206
+ def model_check_query(state: State) -> dict[str, list[AIMessage]]:
207
+ return {"messages": [query_check.invoke({"messages": [state["messages"][-1]]})]}
208
+
209
+ # -------------------------------------------------------------------------
210
+ # Get tools for listing tables and fetching schema.
211
+ # -------------------------------------------------------------------------
212
+
213
+ list_tables_tool = next((tool for tool in tools_instance if tool.name == "sql_db_list_tables"), None)
214
+ get_schema_tool = next((tool for tool in tools_instance if tool.name == "sql_db_schema"), None)
215
+
216
+ workflow = StateGraph(State)
217
+ workflow.add_node("first_tool_call", first_tool_call)
218
+ workflow.add_node("list_tables_tool", create_tool_node_with_fallback([list_tables_tool]))
219
+ workflow.add_node("get_schema_tool", create_tool_node_with_fallback([get_schema_tool]))
220
+ model_get_schema = llm.bind_tools([get_schema_tool])
221
+ workflow.add_node("model_get_schema", lambda state: {"messages": [model_get_schema.invoke(state["messages"])],})
222
+ workflow.add_node("query_gen", query_gen_node)
223
+ workflow.add_node("correct_query", model_check_query)
224
+ workflow.add_node("execute_query", create_tool_node_with_fallback([db_query_tool]))
225
+
226
+ workflow.add_edge(START, "first_tool_call")
227
+ workflow.add_edge("first_tool_call", "list_tables_tool")
228
+ workflow.add_edge("list_tables_tool", "model_get_schema")
229
+ workflow.add_edge("model_get_schema", "get_schema_tool")
230
+ workflow.add_edge("get_schema_tool", "query_gen")
231
+ workflow.add_conditional_edges("query_gen", should_continue)
232
+ workflow.add_edge("correct_query", "execute_query")
233
+ workflow.add_edge("execute_query", "query_gen")
234
+
235
+ # Return compiled workflow
236
+ return workflow.compile()
237
+
238
+
239
+ # =============================================================================
240
+ # create_app: The application factory.
241
+ # =============================================================================
242
+
243
+ def create_app():
244
+ flask_app = Flask(__name__, static_url_path='/uploads', static_folder='uploads')
245
+ socketio = SocketIO(flask_app, cors_allowed_origins="*")
246
+
247
+ # Ensure uploads folder exists.
248
+ if not os.path.exists(UPLOAD_FOLDER):
249
+ os.makedirs(UPLOAD_FOLDER)
250
+ flask_app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
251
+
252
+ # -------------------------------------------------------------------------
253
+ # Serve uploaded files via a custom route.
254
+ # -------------------------------------------------------------------------
255
+
256
+ @flask_app.route("/files/<path:filename>")
257
+ def uploaded_file(filename):
258
+ return send_from_directory(flask_app.config['UPLOAD_FOLDER'], filename)
259
+
260
+ # -------------------------------------------------------------------------
261
+ # Helper: run_agent runs the agent with the given prompt.
262
+ # -------------------------------------------------------------------------
263
+
264
+ def run_agent(prompt, socketio):
265
+ global agent_app, abs_file_path, db_path
266
+ if not abs_file_path:
267
+ socketio.emit("log", {"message": "[ERROR]: No DB file uploaded."})
268
+ socketio.emit("final", {"message": "No database available. Please upload one and try again."})
269
+ return
270
+
271
+ try:
272
+ # Lazy agent initialization: use the previously uploaded DB.
273
+ if agent_app is None:
274
+ print("[INFO]: Initializing agent for the first time...")
275
+ agent_app = create_agent_app(abs_file_path)
276
+ socketio.emit("log", {"message": "[INFO]: Agent initialized."})
277
+
278
+ query = {"messages": [("user", prompt)]}
279
+ result = agent_app.invoke(query)
280
+ try:
281
+ result = result["messages"][-1].tool_calls[0]["args"]["final_answer"]
282
+ except Exception:
283
+ result = "Query failed or no valid answer found."
284
+
285
+ print("final_answer------>", result)
286
+ socketio.emit("final", {"message": result})
287
+ except Exception as e:
288
+ print(f"[ERROR]: {str(e)}")
289
+ socketio.emit("log", {"message": f"[ERROR]: {str(e)}"})
290
+ socketio.emit("final", {"message": "Generation failed."})
291
+
292
+ # -------------------------------------------------------------------------
293
+ # Route: index page.
294
+ # -------------------------------------------------------------------------
295
+
296
+ @flask_app.route("/")
297
+ def index():
298
+ return render_template("index.html")
299
+
300
+ # -------------------------------------------------------------------------
301
+ # Route: generate (POST) – receives a prompt and runs the agent.
302
+ # -------------------------------------------------------------------------
303
+
304
+ @flask_app.route("/generate", methods=["POST"])
305
+ def generate():
306
+ try:
307
+ socketio.emit("log", {"message": "[STEP]: Entering query_gen..."})
308
+ data = request.json
309
+ prompt = data.get("prompt", "")
310
+ socketio.emit("log", {"message": f"[INFO]: Received prompt: {prompt}"})
311
+ thread = threading.Thread(target=run_agent, args=(prompt, socketio))
312
+ socketio.emit("log", {"message": f"[INFO]: Starting thread: {thread}"})
313
+ thread.start()
314
+ return "OK", 200
315
+ except Exception as e:
316
+ print(f"[ERROR]: {str(e)}")
317
+ socketio.emit("log", {"message": f"[ERROR]: {str(e)}"})
318
+ return "ERROR", 500
319
+
320
+ # -------------------------------------------------------------------------
321
+ # Route: upload (GET/POST) – handles uploading the SQLite DB file.
322
+ # -------------------------------------------------------------------------
323
+
324
+ @flask_app.route("/upload", methods=["GET", "POST"])
325
+ def upload():
326
+ global abs_file_path, agent_app, db_path
327
+ try:
328
+ if request.method == "POST":
329
+ file = request.files.get("file")
330
+ if not file:
331
+ print("No file uploaded")
332
+ return "No file uploaded", 400
333
+ filename = secure_filename(file.filename)
334
+ if filename.endswith('.db'):
335
+ db_path = os.path.join(flask_app.config['UPLOAD_FOLDER'], "uploaded.db")
336
+ print("Saving file to:", db_path)
337
+ file.save(db_path)
338
+ abs_file_path = os.path.abspath(db_path) # Save it here; agent init will occur on first query.
339
+ print(f"[INFO]: File '{filename}' uploaded. Agent will be initialized on first query.")
340
+ socketio.emit("log", {"message": f"[INFO]: Database file '{filename}' uploaded."})
341
+ return redirect(url_for("index"))
342
+ return render_template("upload.html")
343
+ except Exception as e:
344
+ print(f"[ERROR]: {str(e)}")
345
+ socketio.emit("log", {"message": f"[ERROR]: {str(e)}"})
346
+ return render_template("upload.html")
347
+
348
+ return flask_app, socketio
349
+
350
+ # =============================================================================
351
+ # Create the app for Gunicorn compatibility.
352
+ # =============================================================================
353
+
354
+ app, socketio_instance = create_app()
355
+
356
+ if __name__ == "__main__":
357
+ socketio_instance.run(app, debug=True)