doc_db_integration
#32
by
nolanzandi
- opened
- app.py +3 -1
- data_sources/__init__.py +2 -1
- data_sources/connect_doc_db.py +36 -0
- data_sources/connect_sql_db.py +1 -1
- functions/__init__.py +4 -4
- functions/chat_functions.py +79 -0
- functions/query_functions.py +72 -0
- requirements.txt +3 -0
- templates/doc_db.py +94 -0
- tools/tools.py +36 -1
app.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
from utils import TEMP_DIR, message_dict
|
2 |
import gradio as gr
|
3 |
-
import templates.data_file as data_file, templates.sql_db as sql_db
|
4 |
|
5 |
import os
|
6 |
from getpass import getpass
|
@@ -74,6 +74,8 @@ with gr.Blocks(theme=theme, css=css, head=head, delete_cache=(3600,3600)) as dem
|
|
74 |
data_file.demo.render()
|
75 |
with gr.Tab("SQL Database"):
|
76 |
sql_db.demo.render()
|
|
|
|
|
77 |
|
78 |
footer = gr.HTML("""<!-- Footer -->
|
79 |
<footer class="max-w-4xl mx-auto mt-12 text-center text-gray-500 text-sm">
|
|
|
1 |
from utils import TEMP_DIR, message_dict
|
2 |
import gradio as gr
|
3 |
+
import templates.data_file as data_file, templates.sql_db as sql_db, templates.doc_db as doc_db
|
4 |
|
5 |
import os
|
6 |
from getpass import getpass
|
|
|
74 |
data_file.demo.render()
|
75 |
with gr.Tab("SQL Database"):
|
76 |
sql_db.demo.render()
|
77 |
+
with gr.Tab("Document (MongoDB) Database"):
|
78 |
+
doc_db.demo.render()
|
79 |
|
80 |
footer = gr.HTML("""<!-- Footer -->
|
81 |
<footer class="max-w-4xl mx-auto mt-12 text-center text-gray-500 text-sm">
|
data_sources/__init__.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
from .upload_file import process_data_upload
|
2 |
from .connect_sql_db import connect_sql_db
|
|
|
3 |
|
4 |
-
__all__ = ["process_data_upload","connect_sql_db"]
|
|
|
1 |
from .upload_file import process_data_upload
|
2 |
from .connect_sql_db import connect_sql_db
|
3 |
+
from .connect_doc_db import connect_doc_db
|
4 |
|
5 |
+
__all__ = ["process_data_upload","connect_sql_db","connect_doc_db"]
|
data_sources/connect_doc_db.py
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pymongo import MongoClient
|
2 |
+
import os
|
3 |
+
from utils import TEMP_DIR
|
4 |
+
from pymongo_schema.extract import extract_pymongo_client_schema
|
5 |
+
|
6 |
+
def connect_doc_db(connection_string, nosql_db_name, session_hash):
    """Connect to MongoDB and collect collection names plus a schema description.

    Args:
        connection_string: MongoDB URI passed straight to ``MongoClient``.
        nosql_db_name: Database whose collection names are listed.
        session_hash: Session identifier; a ``<TEMP_DIR>/<session_hash>/doc_db``
            directory is created for it.

    Returns:
        On success ``["success", html_message, collection_names, schema]``;
        on any failure ``["error", html_message]``.
    """
    client = None
    try:
        # Create a MongoClient object
        client = MongoClient(connection_string)
        print("Connected to NoSQL Mongo DB")

        # Access a database
        db = client[nosql_db_name]

        collection_names = db.list_collection_names()

        print(collection_names)

        # The schema is extracted from the client as a whole, not only from
        # nosql_db_name.
        schema = extract_pymongo_client_schema(client)

        session_path = 'doc_db'

        dir_path = TEMP_DIR / str(session_hash) / str(session_path)
        os.makedirs(dir_path, exist_ok=True)

        return ["success","<p style='color:green;text-align:center;font-size:18px;'>Document database connected successful</p>", collection_names, schema]
    except Exception as e:
        print("DocDB CONNECTION ERROR")
        print(e)
        return ["error",f"<p style='color:red;text-align:center;font-size:18px;font-weight:bold;'>ERROR: {e}</p>"]
    finally:
        # Close the connection on every path, including failures after the
        # client was created.
        if client is not None:
            client.close()
            print("MongoDB Connection closed.")
|
36 |
+
|
data_sources/connect_sql_db.py
CHANGED
@@ -36,7 +36,7 @@ def connect_sql_db(url, sql_user, sql_port, sql_pass, sql_db_name, session_hash)
|
|
36 |
|
37 |
return ["success","<p style='color:green;text-align:center;font-size:18px;'>SQL database connected successful</p>", table_names]
|
38 |
except Exception as e:
|
39 |
-
print("
|
40 |
print(e)
|
41 |
return ["error",f"<p style='color:red;text-align:center;font-size:18px;font-weight:bold;'>ERROR: {e}</p>"]
|
42 |
|
|
|
36 |
|
37 |
return ["success","<p style='color:green;text-align:center;font-size:18px;'>SQL database connected successful</p>", table_names]
|
38 |
except Exception as e:
|
39 |
+
print("SQL DB CONNECTION ERROR")
|
40 |
print(e)
|
41 |
return ["error",f"<p style='color:red;text-align:center;font-size:18px;font-weight:bold;'>ERROR: {e}</p>"]
|
42 |
|
functions/__init__.py
CHANGED
@@ -1,9 +1,9 @@
|
|
1 |
-
from .query_functions import SQLiteQuery, sqlite_query_func, PostgreSQLQuery, sql_query_func
|
2 |
from .chart_functions import table_generation_func, scatter_chart_generation_func, \
|
3 |
line_chart_generation_func, bar_chart_generation_func, pie_chart_generation_func, histogram_generation_func, scatter_chart_fig
|
4 |
-
from .chat_functions import sql_example_question_generator, example_question_generator, chatbot_with_fc, sql_chatbot_with_fc
|
5 |
from .stat_functions import regression_func
|
6 |
|
7 |
-
__all__ = ["SQLiteQuery","sqlite_query_func","sql_query_func","table_generation_func","scatter_chart_generation_func",
|
8 |
"line_chart_generation_func","bar_chart_generation_func","regression_func", "pie_chart_generation_func", "histogram_generation_func",
|
9 |
-
"scatter_chart_fig","sql_example_question_generator","example_question_generator","chatbot_with_fc","sql_chatbot_with_fc"]
|
|
|
1 |
+
from .query_functions import SQLiteQuery, sqlite_query_func, PostgreSQLQuery, sql_query_func, doc_db_query_func
|
2 |
from .chart_functions import table_generation_func, scatter_chart_generation_func, \
|
3 |
line_chart_generation_func, bar_chart_generation_func, pie_chart_generation_func, histogram_generation_func, scatter_chart_fig
|
4 |
+
from .chat_functions import sql_example_question_generator, example_question_generator, doc_db_example_question_generator, chatbot_with_fc, sql_chatbot_with_fc, doc_db_chatbot_with_fc
|
5 |
from .stat_functions import regression_func
|
6 |
|
7 |
+
__all__ = ["SQLiteQuery","sqlite_query_func","sql_query_func","doc_db_query_func","table_generation_func","scatter_chart_generation_func",
|
8 |
"line_chart_generation_func","bar_chart_generation_func","regression_func", "pie_chart_generation_func", "histogram_generation_func",
|
9 |
+
"scatter_chart_fig","doc_db_example_question_generator","sql_example_question_generator","example_question_generator","chatbot_with_fc","sql_chatbot_with_fc","doc_db_chatbot_with_fc"]
|
functions/chat_functions.py
CHANGED
@@ -57,6 +57,26 @@ def sql_example_question_generator(session_hash, db_tables, db_name):
|
|
57 |
|
58 |
return example_response["replies"][0].text
|
59 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
60 |
def chatbot_with_fc(message, history, session_hash):
|
61 |
from functions import sqlite_query_func, table_generation_func, regression_func, scatter_chart_generation_func, \
|
62 |
line_chart_generation_func,bar_chart_generation_func,pie_chart_generation_func,histogram_generation_func
|
@@ -170,4 +190,63 @@ def sql_chatbot_with_fc(message, history, session_hash, db_url, db_port, db_user
|
|
170 |
message_dict[session_hash]['sql'].append(response["replies"][0])
|
171 |
break
|
172 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
173 |
return response["replies"][0].text
|
|
|
57 |
|
58 |
return example_response["replies"][0].text
|
59 |
|
60 |
+
def doc_db_example_question_generator(session_hash, db_collections, db_name, db_schema):
    """Ask the chat model for seven example analysis questions about a MongoDB database.

    Args:
        session_hash: Session identifier (not used by this function).
        db_collections: Collection names, interpolated into the prompt.
        db_name: Database name, interpolated into the system message.
        db_schema: Schema description, interpolated into the prompt.

    Returns:
        The text of the model's first reply; the prompt asks for an array of
        question strings.
    """
    system_prompt = ChatMessage.from_system(
        f"You are a helpful and knowledgeable agent who has access to an MongoDB NoSQL document database called {db_name}."
    )
    question_request = ChatMessage.from_user(text=f"""We have a MongoDB NoSQL document database with the following collections: {db_collections}.
The schema of these collections is: {db_schema}.
We also have an AI agent with access to the same database that will be performing data analysis.
Please return an array of seven strings, each one being a question for our data analysis agent
that we can suggest that you believe will be insightful or helpful to a data analysis looking for
data insights. Return nothing more than the array of questions because I need that specific data structure
to process your response. No other response type or data structure will work.""")

    generated = chat_generator.run(messages=[system_prompt, question_request])

    first_reply = generated["replies"][0]
    return first_reply.text
|
79 |
+
|
80 |
def chatbot_with_fc(message, history, session_hash):
|
81 |
from functions import sqlite_query_func, table_generation_func, regression_func, scatter_chart_generation_func, \
|
82 |
line_chart_generation_func,bar_chart_generation_func,pie_chart_generation_func,histogram_generation_func
|
|
|
190 |
message_dict[session_hash]['sql'].append(response["replies"][0])
|
191 |
break
|
192 |
|
193 |
+
return response["replies"][0].text
|
194 |
+
|
195 |
+
def doc_db_chatbot_with_fc(message, history, session_hash, db_connection_string, db_name, db_collections, db_schema):
    """Answer one chat turn about a MongoDB database, running any tools the model requests.

    The per-session conversation is kept in ``message_dict[session_hash]['doc_db']``.
    The model is re-run after each batch of tool results until it produces a
    reply without tool calls.

    Args:
        message: The user's new chat message.
        history: Chat history supplied by the UI (not used; the conversation
            is tracked in ``message_dict``).
        session_hash: Key of the session in ``message_dict``.
        db_connection_string: MongoDB URI forwarded to the tool functions.
        db_name: Database name forwarded to the tool functions.
        db_collections: Collection names used in the system prompt and tool description.
        db_schema: Schema description used in the system prompt.

    Returns:
        The text of the model's final reply.
    """
    from functions import doc_db_query_func, table_generation_func, regression_func, scatter_chart_generation_func, \
        line_chart_generation_func,bar_chart_generation_func,pie_chart_generation_func,histogram_generation_func
    import tools.tools as tools

    available_functions = {"doc_db_query_func": doc_db_query_func,"table_generation_func":table_generation_func,
                           "line_chart_generation_func":line_chart_generation_func,"bar_chart_generation_func":bar_chart_generation_func,
                           "scatter_chart_generation_func":scatter_chart_generation_func, "pie_chart_generation_func":pie_chart_generation_func,
                           "histogram_generation_func":histogram_generation_func,
                           "regression_func":regression_func }

    # Tolerate a session that has no entry yet instead of raising KeyError.
    if session_hash not in message_dict:
        message_dict[session_hash] = {}
    session_messages = message_dict[session_hash]

    if session_messages.get('doc_db') is not None:
        session_messages['doc_db'].append(ChatMessage.from_user(message))
    else:
        messages = [
            ChatMessage.from_system(
                f"""You are a helpful and knowledgeable agent who has access to an NoSQL MongoDB Document database which has a series of collections called {db_collections}.
The schema of these collections is: {db_schema}.
You also have access to a function, called table_generation_func, that can take a query.csv file generated from our MongoDB query and returns an iframe that we should display in our chat window.
You also have access to a scatter plot function, called scatter_chart_generation_func, that can take a query.csv file generated from our MongoDB query and uses plotly dictionaries to generate a scatter plot and returns an iframe that we should display in our chat window.
You also have access to a line chart function, called line_chart_generation_func, that can take a query.csv file generated from our MongoDB query and uses plotly dictionaries to generate a line chart and returns an iframe that we should display in our chat window.
You also have access to a bar graph function, called bar_chart_generation_func, that can take a query.csv file generated from our MongoDB query and uses plotly dictionaries to generate a bar graph and returns an iframe that we should display in our chat window.
You also have access to a pie chart function, called pie_chart_generation_func, that can take a query.csv file generated from our MongoDB query and uses plotly dictionaries to generate a pie chart and returns an iframe that we should display in our chat window.
You also have access to a histogram function, called histogram_generation_func, that can take a query.csv file generated from our MongoDB query and uses plotly dictionaries to generate a histogram and returns an iframe that we should display in our chat window.
You also have access to a linear regression function, called regression_func, that can take a query.csv file generated from our MongoDB query and a list of column names for our independent and dependent variables and return a regression data string and a regression chart which is returned as an iframe.
Could you please always display the generated charts, tables, and visualizations as part of your output?"""
            )
        ]
        messages.append(ChatMessage.from_user(message))
        session_messages['doc_db'] = messages

    conversation = session_messages['doc_db']
    # The tool definitions depend only on db_collections, so build them once.
    tool_definitions = tools.doc_db_tools_call(db_collections)

    response = chat_generator.run(messages=conversation, generation_kwargs={"tools": tool_definitions})

    while True:
        reply = response["replies"][0]

        # A reply is treated as a tool request only when it actually carries
        # tool calls; otherwise there would be nothing to execute before
        # asking the model again.
        if reply.tool_calls:
            for function_call in reply.tool_calls:
                conversation.append(ChatMessage.from_assistant(tool_calls=[function_call]))
                ## Parse function calling information
                function_name = function_call.tool_name
                function_args = function_call.arguments

                ## Find the corresponding function and call it with the given arguments
                function_to_call = available_functions.get(function_name)
                if function_to_call is None:
                    # Report an unknown tool back to the model rather than crashing the chat.
                    tool_result = (
                        f"There is no function called {function_name}. "
                        f"The available functions are: {', '.join(available_functions)}."
                    )
                else:
                    function_response = function_to_call(**function_args, session_hash=session_hash, connection_string=db_connection_string,
                                                         doc_db_name=db_name, session_folder='doc_db')
                    tool_result = function_response['reply']
                print(function_name)
                ## Append function response to the messages list using `ChatMessage.from_tool`
                conversation.append(ChatMessage.from_tool(tool_result=tool_result, origin=function_call))
            response = chat_generator.run(messages=conversation, generation_kwargs={"tools": tool_definitions})

        # Regular Conversation
        else:
            conversation.append(reply)
            break

    return response["replies"][0].text
|
functions/query_functions.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
from typing import List
|
|
|
2 |
from haystack import component
|
3 |
import pandas as pd
|
4 |
pd.set_option('display.max_rows', None)
|
@@ -7,7 +8,10 @@ pd.set_option('display.width', None)
|
|
7 |
pd.set_option('display.max_colwidth', None)
|
8 |
import sqlite3
|
9 |
import psycopg2
|
|
|
|
|
10 |
from utils import TEMP_DIR
|
|
|
11 |
|
12 |
@component
|
13 |
class SQLiteQuery:
|
@@ -93,3 +97,71 @@ def sql_query_func(queries: List[str], session_hash, db_url, db_port, db_user, d
|
|
93 |
"""
|
94 |
print(reply)
|
95 |
return {"reply": reply}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
from typing import List
|
2 |
+
from typing import AnyStr
|
3 |
from haystack import component
|
4 |
import pandas as pd
|
5 |
pd.set_option('display.max_rows', None)
|
|
|
8 |
pd.set_option('display.max_colwidth', None)
|
9 |
import sqlite3
|
10 |
import psycopg2
|
11 |
+
from pymongo import MongoClient
|
12 |
+
import pymongoarrow.monkey
|
13 |
from utils import TEMP_DIR
|
14 |
+
import ast
|
15 |
|
16 |
@component
|
17 |
class SQLiteQuery:
|
|
|
97 |
"""
|
98 |
print(reply)
|
99 |
return {"reply": reply}
|
100 |
+
|
101 |
+
@component
class DocDBQuery:
    """Haystack component that runs an aggregation pipeline on one MongoDB database.

    The result is written to ``<TEMP_DIR>/<session_hash>/doc_db/query.csv`` and
    also returned as text. The client is closed at the end of ``run``, so an
    instance is good for a single query.
    """

    def __init__(self, connection_string: str, doc_db_name: str):
        """Open a client for ``connection_string`` and select ``doc_db_name``."""
        client = MongoClient(connection_string)

        self.client = client
        self.connection = client[doc_db_name]

    @staticmethod
    def _parse_pipeline(aggregation_pipeline):
        """Convert the model-supplied pipeline into a list of stage dicts.

        JSON is tried first because it handles ``true``/``false``/``null`` and
        leaves string values untouched. If that fails, the text is treated as
        a Python literal after mapping JSON keywords that directly follow a
        colon to their Python spellings.
        """
        import json
        import re

        if not isinstance(aggregation_pipeline, str):
            # Already structured (e.g. the model sent an array, not a string).
            return list(aggregation_pipeline)

        try:
            parsed = json.loads(aggregation_pipeline)
        except ValueError:
            python_keywords = {"true": "True", "false": "False", "null": "None"}
            python_text = re.sub(
                r":\s*(true|false|null)\b",
                lambda match: ":" + python_keywords[match.group(1)],
                aggregation_pipeline,
            )
            parsed = ast.literal_eval(python_text)

        # A single stage given as a bare object is wrapped so that the
        # aggregate call always receives a list.
        if isinstance(parsed, dict):
            parsed = [parsed]
        return parsed

    @component.output_types(results=List[str], queries=List[str])
    def run(self, aggregation_pipeline: List[str], db_collection, session_hash):
        """Run the pipeline on ``db_collection`` and save the result as CSV.

        Args:
            aggregation_pipeline: The pipeline as text (JSON or Python-literal
                syntax). The annotation is kept as declared, although callers
                in this file pass a single string.
            db_collection: Name of the collection to aggregate.
            session_hash: Session identifier used to locate the output folder.

        Returns:
            ``{"results": [<result rendered as text>], "queries": <pipeline as received>}``.

        Raises:
            Whatever parsing, the aggregation, or the CSV write raises; the
            client is closed first.
        """
        pymongoarrow.monkey.patch_all()
        print("ATTEMPTING TO RUN MONGODB QUERY")
        dir_path = TEMP_DIR / str(session_hash)
        results = []
        print(aggregation_pipeline)

        try:
            query_list = self._parse_pipeline(aggregation_pipeline)

            print("QUERY List")
            print(query_list)
            print(db_collection)

            db = self.connection
            collection = db[db_collection]

            print(collection)
            docs = collection.aggregate_pandas_all(query_list)
            print("DATA FRAME COMPLETE")
            docs.to_csv(f'{dir_path}/doc_db/query.csv', index=False)
            print("CSV COMPLETE")
            results.append(f"{docs}")
        finally:
            # Release the connection even when parsing or the query fails.
            self.client.close()
        return {"results": results, "queries": aggregation_pipeline}
|
147 |
+
|
148 |
+
|
149 |
+
|
150 |
+
def doc_db_query_func(aggregation_pipeline: List[str], db_collection: AnyStr, session_hash, connection_string, doc_db_name, **kwargs):
    """Tool entry point: run a MongoDB aggregation pipeline and report the outcome.

    Args:
        aggregation_pipeline: Pipeline text produced by the model.
        db_collection: Collection to query.
        session_hash: Session identifier used for the query.csv location.
        connection_string: MongoDB URI.
        doc_db_name: Database name.
        **kwargs: Extra keyword arguments from the shared tool dispatcher; ignored.

    Returns:
        ``{"reply": text}`` where text is the rendered result, a note that the
        result is too large (over 1000 characters), or an error description.
    """
    try:
        # Build the component inside the try so that a malformed connection
        # string is reported to the model like any other query failure.
        doc_db_query = DocDBQuery(connection_string, doc_db_name)
        result = doc_db_query.run(aggregation_pipeline, db_collection, session_hash)
        print("RESULT")
        if len(result["results"][0]) > 1000:
            print("QUERY TOO LARGE")
            return {"reply": "query result too large to be processed by llm, the query results are in our query.csv file. If you need to display the results directly, perhaps use the table_generation_func function."}
        else:
            return {"reply": result["results"][0]}

    except Exception as e:
        reply = f"""There was an error running the NoSQL (Mongo) Query = {aggregation_pipeline}
The error is {e},
You should probably try again.
"""
        print(reply)
        return {"reply": reply}
|
requirements.txt
CHANGED
@@ -7,3 +7,6 @@ openpyxl
|
|
7 |
statsmodels
|
8 |
xlrd
|
9 |
psycopg2-binary
|
|
|
|
|
|
|
|
7 |
statsmodels
|
8 |
xlrd
|
9 |
psycopg2-binary
|
10 |
+
pymongo
|
11 |
+
pymongoarrow
|
12 |
+
pymongo_schema
|
templates/doc_db.py
ADDED
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import ast
|
2 |
+
import gradio as gr
|
3 |
+
from functions import doc_db_example_question_generator, doc_db_chatbot_with_fc
|
4 |
+
from data_sources import connect_doc_db
|
5 |
+
from utils import message_dict
|
6 |
+
|
7 |
+
def hide_info():
    """Return a Gradio update that makes the target component invisible."""
    hidden = gr.update(visible=False)
    return hidden
|
9 |
+
|
10 |
+
# Document-database tab: connection form plus a chat interface that is
# rendered after the user submits the form.
# NOTE(review): indentation was reconstructed; confirm nesting against the real file.
with gr.Blocks() as demo:
    description = gr.HTML("""
<!-- Header -->
<div class="max-w-4xl mx-auto mb-12 text-center">
<div class="bg-blue-50 border border-blue-200 rounded-lg max-w-2xl mx-auto">
<p>This tool allows users to communicate with and query real time data from a Document DB (MongoDB for now, others can be added if requested) using natural
language and the above features.</p>
<p style="font-weight:bold;">Notice: the way this system is designed, no login information is retained and credentials are passed as session variables until the user leaves or
refreshes the page in which they disappear. They are never saved to any files. I also make use of the PyMongoArrow aggregate_pandas_all function to apply pipelines,
which can't delete, drop, or add database lines to avoid unhappy accidents or glitches.
That being said, it's probably not a good idea to connect a production database to a strange AI tool with an unfamiliar author.
This should be for demonstration purposes.</p>
<p>Contact me if this is something you would like built in your organization, on your infrastructure, and with the requisite privacy and control a production
database analytics tool requires.</p>
</div>
</div>
""", elem_classes="description_component")
    connection_string = gr.Textbox(label="Connection String", value="dataanalyst0.l1klmww.mongodb.net/")
    with gr.Row():
        connection_user = gr.Textbox(label="Connection User", value="virtual-data-analyst")
        # NOTE(review): a database password is committed here as the default
        # value; confirm it is an intentionally public, read-only demo account.
        connection_password = gr.Textbox(label="Connection Password", value="zcpbmoGJ3mC8o", type="password")
        # NOTE(review): assumed to sit in the same row as user/password — confirm.
        doc_db_name = gr.Textbox(label="Database Name", value="sample_mflix")

    submit = gr.Button(value="Submit")
    submit.click(fn=hide_info, outputs=description)

    @gr.render(inputs=[connection_string,connection_user,connection_password,doc_db_name], triggers=[submit.click])
    def sql_chat(request: gr.Request, connection_string=connection_string.value, connection_user=connection_user.value, connection_password=connection_password.value, doc_db_name=doc_db_name.value):
        """Connect with the submitted credentials and render the chat UI on success."""
        from urllib.parse import quote_plus

        if request.session_hash not in message_dict:
            message_dict[request.session_hash] = {}
        # NOTE(review): assumed to run on every submit (outside the `if`) so a
        # new connection starts a new conversation — confirm.
        message_dict[request.session_hash]['doc_db'] = None

        # User name and password must be percent-escaped inside a MongoDB URI,
        # otherwise characters such as '@', ':' or '/' break URI parsing.
        connection_login_value = (
            "mongodb+srv://" + quote_plus(connection_user) + ":" + quote_plus(connection_password) + "@" + connection_string
        )

        print("MONGO APP")
        process_message = process_doc_db(connection_login_value, doc_db_name, request.session_hash)
        gr.HTML(value=process_message[1], padding=False)
        if process_message[0] == "success":
            if "dataanalyst0.l1klmww.mongodb.net" in connection_login_value:
                # Curated questions for the bundled demo database.
                example_questions = [
                    ["Describe the dataset"],
                    ["What are the top 5 most common movie genres?"],
                    ["How do user comment counts on a movie correlate with the movie award wins?"],
                    ["Can you generate a pie chart showing the top 10 states with the most movie theaters?"],
                    ["What are the top 10 most represented directors in the database?"],
                    ["What are the different movie categories and how many movies are in each category?"]
                ]
            else:
                try:
                    # The generator's reply is expected to be a list literal of strings.
                    generated_examples = ast.literal_eval(doc_db_example_question_generator(request.session_hash, process_message[2], doc_db_name, process_message[3]))
                    example_questions = [
                        ["Describe the dataset"]
                    ]
                    for example in generated_examples:
                        example_questions.append([example])
                except Exception as e:
                    # Fall back to generic questions when generation or parsing fails.
                    print("DOC DB QUESTION GENERATION ERROR")
                    print(e)
                    example_questions = [
                        ["Describe the dataset"],
                        ["List the columns in the dataset"],
                        ["What could this data be used for?"],
                    ]
            # Hidden components carry per-session values into the chat
            # function as additional inputs.
            session_hash = gr.Textbox(visible=False, value=request.session_hash)
            db_connection_string = gr.Textbox(visible=False, value=connection_login_value)
            db_name = gr.Textbox(visible=False, value=doc_db_name)
            db_collections = gr.Textbox(value=process_message[2], interactive=False, label="DB Collections")
            db_schema = gr.Textbox(visible=False, value=process_message[3])
            bot = gr.Chatbot(type='messages', label="CSV Chat Window", render_markdown=True, sanitize_html=False, show_label=True, render=False, visible=True, elem_classes="chatbot")
            chat = gr.ChatInterface(
                fn=doc_db_chatbot_with_fc,
                type='messages',
                chatbot=bot,
                title="Chat with your Database",
                examples=example_questions,
                concurrency_limit=None,
                additional_inputs=[session_hash, db_connection_string, db_name, db_collections,db_schema]
            )
|
87 |
+
|
88 |
+
def process_doc_db(connection_string, nosql_db_name, session_hash):
    """Connect to the document database, or report that no connection string was given.

    Args:
        connection_string: MongoDB URI; may be empty or None.
        nosql_db_name: Database name forwarded to ``connect_doc_db``.
        session_hash: Session identifier forwarded to ``connect_doc_db``.

    Returns:
        The list returned by ``connect_doc_db``, or an ``["error", html]`` list
        when ``connection_string`` is empty, so callers can always index the
        result.
    """
    if not connection_string:
        return ["error","<p style='color:red;text-align:center;font-size:18px;font-weight:bold;'>ERROR: No connection string was provided</p>"]
    return connect_doc_db(connection_string, nosql_db_name, session_hash)
|
92 |
+
|
93 |
+
if __name__ == "__main__":
|
94 |
+
demo.launch()
|
tools/tools.py
CHANGED
@@ -57,7 +57,7 @@ def sql_tools_call(db_tables):
|
|
57 |
"function": {
|
58 |
"name": "sql_query_func",
|
59 |
"description": f"""This is a tool useful to query a PostgreSQL database with the following tables, {table_string}.
|
60 |
-
There may also be more tables in the database if the number of
|
61 |
This function also saves the results of the query to csv file called query.csv.""",
|
62 |
"parameters": {
|
63 |
"type": "object",
|
@@ -79,4 +79,39 @@ def sql_tools_call(db_tables):
|
|
79 |
tools_calls.extend(chart_tools)
|
80 |
tools_calls.extend(stats_tools)
|
81 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
82 |
return tools_calls
|
|
|
57 |
"function": {
|
58 |
"name": "sql_query_func",
|
59 |
"description": f"""This is a tool useful to query a PostgreSQL database with the following tables, {table_string}.
|
60 |
+
There may also be more tables in the database if the number of tables is too large to process.
|
61 |
This function also saves the results of the query to csv file called query.csv.""",
|
62 |
"parameters": {
|
63 |
"type": "object",
|
|
|
79 |
tools_calls.extend(chart_tools)
|
80 |
tools_calls.extend(stats_tools)
|
81 |
|
82 |
+
return tools_calls
|
83 |
+
|
84 |
+
def doc_db_tools_call(db_collections):
    """Build the tool definitions offered to the model for a MongoDB session.

    Args:
        db_collections: Collection names; converted to text and truncated to
            625 characters for the tool description.

    Returns:
        A new list containing the ``doc_db_query_func`` definition followed by
        the shared chart and statistics tool definitions.
    """
    collection_text = str(db_collections)
    collection_string = (collection_text[:625] + '..') if len(collection_text) > 625 else collection_text

    tools_calls = [
        {
            "type": "function",
            "function": {
                "name": "doc_db_query_func",
                "description": f"""This is a tool useful to build an aggregation pipeline to query a MongoDB NoSQL document database with the following collections, {collection_string}.
There may also be more collections in the database if the number of collections is too large to process.
This function also saves the results of the query to a csv file called query.csv.""",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "aggregation_pipeline": {
                            "type": "string",
                            "description": "The MongoDB aggregation pipeline to run, written as a JSON array of stage objects. Infer this from the user's message."
                        },
                        "db_collection": {
                            "type": "string",
                            "description": "The name of the MongoDB collection to run the aggregation pipeline on. Infer this from the user's message.",
                        }
                    },
                    # Must name properties that exist above; the handler takes
                    # aggregation_pipeline and db_collection.
                    "required": ["aggregation_pipeline","db_collection"],
                },
            },
        },
    ]

    tools_calls.extend(chart_tools)
    tools_calls.extend(stats_tools)

    return tools_calls
|