Ashhar
committed on
Commit
·
6d149f9
1
Parent(s):
4380c2b
support multiple tables/views
Browse files
app.py
CHANGED
@@ -3,7 +3,6 @@ import os
|
|
3 |
import pandas as pd
|
4 |
from typing import Literal, TypedDict
|
5 |
from sqlalchemy import create_engine, inspect, text
|
6 |
-
import json
|
7 |
from transformers import AutoTokenizer
|
8 |
from utils import pprint
|
9 |
import time
|
@@ -34,7 +33,7 @@ ModelConfig = TypedDict("ModelConfig", {
|
|
34 |
})
|
35 |
|
36 |
MODEL_CONFIG: dict[ModelType, ModelConfig] = {
|
37 |
-
"
|
38 |
"client": anthropic.Anthropic(api_key=os.environ.get("ANTHROPIC_API_KEY")),
|
39 |
"model": "claude-3-5-haiku-20241022",
|
40 |
# "model": "claude-3-5-sonnet-20241022",
|
@@ -42,6 +41,14 @@ MODEL_CONFIG: dict[ModelType, ModelConfig] = {
|
|
42 |
"max_context": 40000,
|
43 |
"tokenizer": AutoTokenizer.from_pretrained("Xenova/claude-tokenizer")
|
44 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
45 |
"GPT_4o": {
|
46 |
"client": OpenAI(api_key=os.environ.get("OPENAI_API_KEY")),
|
47 |
"model": "gpt-4o",
|
@@ -111,7 +118,7 @@ TOOLS_MODEL = MODEL_CONFIG[modelType].get("tools_model") or MODEL
|
|
111 |
MAX_CONTEXT = MODEL_CONFIG[modelType]["max_context"]
|
112 |
tokenizer = MODEL_CONFIG[modelType]["tokenizer"]
|
113 |
|
114 |
-
isClaudeModel = modelType
|
115 |
isDeepSeekModel = modelType.startswith("DEEPSEEK")
|
116 |
|
117 |
|
@@ -211,7 +218,7 @@ def get_table_schema(table_name):
|
|
211 |
|
212 |
def get_sample_data(table_name):
|
213 |
if not st.session_state.engine:
|
214 |
-
return None
|
215 |
|
216 |
query = f"SELECT * FROM {table_name} ORDER BY 1 DESC LIMIT 3"
|
217 |
try:
|
@@ -219,8 +226,8 @@ def get_sample_data(table_name):
|
|
219 |
df = pd.read_sql(query, conn)
|
220 |
return df
|
221 |
except Exception as e:
|
222 |
-
st.error(f"Error fetching sample data: {str(e)}")
|
223 |
-
return
|
224 |
|
225 |
|
226 |
def clean_sql_response(response: str) -> str:
|
@@ -254,28 +261,57 @@ def execute_query(query):
|
|
254 |
|
255 |
|
256 |
def generate_sql_query(user_query):
|
257 |
-
|
258 |
-
|
259 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
260 |
|
261 |
-
|
262 |
-
{json.dumps(st.session_state.table_schema, indent=2)}
|
263 |
|
264 |
-
|
265 |
-
{
|
266 |
|
267 |
Important:
|
268 |
1. Only return the SQL query, nothing else
|
269 |
2. The query should be valid PostgreSQL syntax
|
270 |
3. Do not include any explanations or comments
|
271 |
4. Make sure to handle NULL values appropriately
|
272 |
-
5.
|
|
|
273 |
|
274 |
User Query: {user_query}
|
275 |
"""
|
276 |
-
|
277 |
prompt_tokens = __countTokens(prompt)
|
278 |
-
|
|
|
279 |
|
280 |
# Debug prompt in a Streamlit expander for better organization
|
281 |
# Check if running locally based on Streamlit's origin header
|
@@ -340,56 +376,98 @@ if st.session_state.connection_string:
|
|
340 |
db_objects = [(table, 'Table') for table in tables] + [(view, 'View') for view in views]
|
341 |
db_objects.sort(key=lambda x: x[0]) # Sort alphabetically by name
|
342 |
|
343 |
-
#
|
344 |
-
|
|
|
|
|
|
|
345 |
|
346 |
-
|
347 |
-
|
348 |
-
|
349 |
-
|
350 |
-
|
351 |
-
|
|
|
352 |
|
353 |
-
|
354 |
-
|
355 |
-
|
356 |
-
|
|
|
|
|
357 |
|
358 |
# Create containers for schema and data
|
359 |
schema_container = st.container()
|
360 |
data_container = st.container()
|
361 |
|
362 |
-
#
|
363 |
-
if
|
364 |
-
#
|
365 |
-
if
|
366 |
-
st.session_state.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
367 |
|
368 |
-
#
|
369 |
-
|
370 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
371 |
|
372 |
-
#
|
373 |
with schema_container:
|
374 |
-
|
375 |
-
|
376 |
-
|
377 |
-
|
378 |
-
|
|
|
|
|
|
|
379 |
|
380 |
with data_container:
|
381 |
-
|
382 |
-
|
383 |
-
|
384 |
-
|
385 |
-
|
386 |
-
|
387 |
-
|
388 |
-
|
389 |
-
|
|
|
|
|
|
|
390 |
|
391 |
# Query Input Section
|
392 |
-
if st.session_state.
|
393 |
st.header("3. Query Input")
|
394 |
user_query = st.text_area("Enter your query in plain English")
|
395 |
|
|
|
3 |
import pandas as pd
|
4 |
from typing import Literal, TypedDict
|
5 |
from sqlalchemy import create_engine, inspect, text
|
|
|
6 |
from transformers import AutoTokenizer
|
7 |
from utils import pprint
|
8 |
import time
|
|
|
33 |
})
|
34 |
|
35 |
MODEL_CONFIG: dict[ModelType, ModelConfig] = {
|
36 |
+
"CLAUDE_HAIKU": {
|
37 |
"client": anthropic.Anthropic(api_key=os.environ.get("ANTHROPIC_API_KEY")),
|
38 |
"model": "claude-3-5-haiku-20241022",
|
39 |
# "model": "claude-3-5-sonnet-20241022",
|
|
|
41 |
"max_context": 40000,
|
42 |
"tokenizer": AutoTokenizer.from_pretrained("Xenova/claude-tokenizer")
|
43 |
},
|
44 |
+
"CLAUDE_SONNET": {
|
45 |
+
"client": anthropic.Anthropic(api_key=os.environ.get("ANTHROPIC_API_KEY")),
|
46 |
+
# "model": "claude-3-5-haiku-20241022",
|
47 |
+
# "model": "claude-3-5-sonnet-20241022",
|
48 |
+
"model": "claude-3-5-sonnet-20240620",
|
49 |
+
"max_context": 40000,
|
50 |
+
"tokenizer": AutoTokenizer.from_pretrained("Xenova/claude-tokenizer")
|
51 |
+
},
|
52 |
"GPT_4o": {
|
53 |
"client": OpenAI(api_key=os.environ.get("OPENAI_API_KEY")),
|
54 |
"model": "gpt-4o",
|
|
|
118 |
MAX_CONTEXT = MODEL_CONFIG[modelType]["max_context"]
|
119 |
tokenizer = MODEL_CONFIG[modelType]["tokenizer"]
|
120 |
|
121 |
+
isClaudeModel = modelType.startswith("CLAUDE")
|
122 |
isDeepSeekModel = modelType.startswith("DEEPSEEK")
|
123 |
|
124 |
|
|
|
218 |
|
219 |
def get_sample_data(table_name):
|
220 |
if not st.session_state.engine:
|
221 |
+
return pd.DataFrame() # Return empty DataFrame instead of None
|
222 |
|
223 |
query = f"SELECT * FROM {table_name} ORDER BY 1 DESC LIMIT 3"
|
224 |
try:
|
|
|
226 |
df = pd.read_sql(query, conn)
|
227 |
return df
|
228 |
except Exception as e:
|
229 |
+
st.error(f"Error fetching sample data for {table_name}: {str(e)}")
|
230 |
+
return pd.DataFrame() # Return empty DataFrame on error
|
231 |
|
232 |
|
233 |
def clean_sql_response(response: str) -> str:
|
|
|
261 |
|
262 |
|
263 |
def generate_sql_query(user_query):
|
264 |
+
# Build context for all selected tables
|
265 |
+
tables_context = []
|
266 |
+
for table_name, table_type in st.session_state.selected_tables.items():
|
267 |
+
# Format schema in markdown
|
268 |
+
schema_info = st.session_state.table_schemas[table_name]
|
269 |
+
|
270 |
+
# Build markdown formatted schema
|
271 |
+
schema_md = [f"\n\n### {table_type}: {table_name}"]
|
272 |
+
|
273 |
+
# Add table comment if exists
|
274 |
+
if schema_info.get("table_comment"):
|
275 |
+
schema_md.append(f"> {schema_info['table_comment']}")
|
276 |
+
|
277 |
+
# Add column details
|
278 |
+
schema_md.append("\n**Columns:**")
|
279 |
+
for col_name, col_info in schema_info["columns"].items():
|
280 |
+
col_type = col_info["type"]
|
281 |
+
col_comment = col_info.get("comment")
|
282 |
+
|
283 |
+
# Format column with type and optional comment
|
284 |
+
if col_comment:
|
285 |
+
schema_md.append(f"- `{col_name}` ({col_type}) - {col_comment}")
|
286 |
+
else:
|
287 |
+
schema_md.append(f"- `{col_name}` ({col_type})")
|
288 |
+
|
289 |
+
# Add sample data
|
290 |
+
schema_md.append("\n**Sample Data:**")
|
291 |
+
schema_md.append(st.session_state.sample_data[table_name].to_markdown(index=False))
|
292 |
+
|
293 |
+
# Join all parts with newlines
|
294 |
+
tables_context.append("\n".join(schema_md))
|
295 |
|
296 |
+
prompt = f"""You are a SQL expert. Generate a valid PostgreSQL query based on the following context and user query.
|
|
|
297 |
|
298 |
+
<AVAILABLE_OBJECTS>
|
299 |
+
{chr(10).join(tables_context)}
|
300 |
|
301 |
Important:
|
302 |
1. Only return the SQL query, nothing else
|
303 |
2. The query should be valid PostgreSQL syntax
|
304 |
3. Do not include any explanations or comments
|
305 |
4. Make sure to handle NULL values appropriately
|
306 |
+
5. If joining tables, use appropriate join conditions based on the schema
|
307 |
+
6. Use table names with appropriate qualifiers to avoid ambiguity
|
308 |
|
309 |
User Query: {user_query}
|
310 |
"""
|
311 |
+
|
312 |
prompt_tokens = __countTokens(prompt)
|
313 |
+
print("\n")
|
314 |
+
pprint(f"[{MODEL}] Prompt tokens for SQL generation: {prompt_tokens}")
|
315 |
|
316 |
# Debug prompt in a Streamlit expander for better organization
|
317 |
# Check if running locally based on Streamlit's origin header
|
|
|
376 |
db_objects = [(table, 'Table') for table in tables] + [(view, 'View') for view in views]
|
377 |
db_objects.sort(key=lambda x: x[0]) # Sort alphabetically by name
|
378 |
|
379 |
+
# Extract just the names for the multiselect
|
380 |
+
object_names = [obj[0] for obj in db_objects]
|
381 |
+
|
382 |
+
# Default to 'lsq_leads' if present
|
383 |
+
default_selections = ['lsq_leads'] if 'lsq_leads' in object_names else []
|
384 |
|
385 |
+
# Create multiselect for table/view selection
|
386 |
+
selected_objects = st.multiselect(
|
387 |
+
"Select tables/views",
|
388 |
+
options=object_names,
|
389 |
+
default=default_selections,
|
390 |
+
help="You can select multiple tables/views to query across them"
|
391 |
+
)
|
392 |
|
393 |
+
# Display selected object types
|
394 |
+
if selected_objects:
|
395 |
+
st.write("Selected objects:")
|
396 |
+
for obj in selected_objects:
|
397 |
+
obj_type = next(obj_type for obj_name, obj_type in db_objects if obj_name == obj)
|
398 |
+
st.write(f"- {obj}: {obj_type}")
|
399 |
|
400 |
# Create containers for schema and data
|
401 |
schema_container = st.container()
|
402 |
data_container = st.container()
|
403 |
|
404 |
+
# Initialize or reset session state for selected objects
|
405 |
+
if selected_objects:
|
406 |
+
# Always ensure dictionaries exist in session state
|
407 |
+
if not isinstance(st.session_state.get("selected_tables"), dict):
|
408 |
+
st.session_state.selected_tables = {}
|
409 |
+
if not isinstance(st.session_state.get("table_schemas"), dict):
|
410 |
+
st.session_state.table_schemas = {}
|
411 |
+
if not isinstance(st.session_state.get("sample_data"), dict):
|
412 |
+
st.session_state.sample_data = {}
|
413 |
+
|
414 |
+
# Clear previous data for tables that are no longer selected
|
415 |
+
current_tables = set(selected_objects)
|
416 |
+
previous_tables = set(st.session_state.selected_tables.keys())
|
417 |
+
removed_tables = previous_tables - current_tables
|
418 |
+
|
419 |
+
for table in removed_tables:
|
420 |
+
if table in st.session_state.selected_tables:
|
421 |
+
del st.session_state.selected_tables[table]
|
422 |
+
if table in st.session_state.table_schemas:
|
423 |
+
del st.session_state.table_schemas[table]
|
424 |
+
if table in st.session_state.sample_data:
|
425 |
+
del st.session_state.sample_data[table]
|
426 |
|
427 |
+
# Update session state with new selections
|
428 |
+
for obj in selected_objects:
|
429 |
+
# Update selected tables
|
430 |
+
st.session_state.selected_tables[obj] = next(
|
431 |
+
obj_type for obj_name, obj_type in db_objects if obj_name == obj
|
432 |
+
)
|
433 |
+
|
434 |
+
# Fetch and store schema
|
435 |
+
schema = get_table_schema(obj)
|
436 |
+
if schema:
|
437 |
+
st.session_state.table_schemas[obj] = schema
|
438 |
+
|
439 |
+
# Fetch and store sample data
|
440 |
+
sample_data = get_sample_data(obj)
|
441 |
+
if not sample_data.empty:
|
442 |
+
st.session_state.sample_data[obj] = sample_data
|
443 |
|
444 |
+
# Display schema and sample data for each selected object
|
445 |
with schema_container:
|
446 |
+
st.subheader("Table/View Schemas")
|
447 |
+
for obj in selected_objects:
|
448 |
+
if obj in st.session_state.table_schemas:
|
449 |
+
st.write(f"**{obj} Schema:**")
|
450 |
+
st.json(st.session_state.table_schemas[obj])
|
451 |
+
st.write("---")
|
452 |
+
else:
|
453 |
+
st.warning(f"Could not fetch schema for {obj}")
|
454 |
|
455 |
with data_container:
|
456 |
+
st.subheader("Sample Data")
|
457 |
+
for obj in selected_objects:
|
458 |
+
if obj in st.session_state.sample_data and not st.session_state.sample_data[obj].empty:
|
459 |
+
st.write(f"**{obj} (Last 3 rows):**")
|
460 |
+
st.dataframe(
|
461 |
+
st.session_state.sample_data[obj],
|
462 |
+
use_container_width=True,
|
463 |
+
hide_index=True
|
464 |
+
)
|
465 |
+
st.write("---")
|
466 |
+
else:
|
467 |
+
st.warning(f"No sample data available for {obj}")
|
468 |
|
469 |
# Query Input Section
|
470 |
+
if st.session_state.get("selected_tables"):
|
471 |
st.header("3. Query Input")
|
472 |
user_query = st.text_area("Enter your query in plain English")
|
473 |
|