Ashhar commited on
Commit
6d149f9
·
1 Parent(s): 4380c2b

support multiple tables/views

Browse files
Files changed (1) hide show
  1. app.py +130 -52
app.py CHANGED
@@ -3,7 +3,6 @@ import os
3
  import pandas as pd
4
  from typing import Literal, TypedDict
5
  from sqlalchemy import create_engine, inspect, text
6
- import json
7
  from transformers import AutoTokenizer
8
  from utils import pprint
9
  import time
@@ -34,7 +33,7 @@ ModelConfig = TypedDict("ModelConfig", {
34
  })
35
 
36
  MODEL_CONFIG: dict[ModelType, ModelConfig] = {
37
- "CLAUDE": {
38
  "client": anthropic.Anthropic(api_key=os.environ.get("ANTHROPIC_API_KEY")),
39
  "model": "claude-3-5-haiku-20241022",
40
  # "model": "claude-3-5-sonnet-20241022",
@@ -42,6 +41,14 @@ MODEL_CONFIG: dict[ModelType, ModelConfig] = {
42
  "max_context": 40000,
43
  "tokenizer": AutoTokenizer.from_pretrained("Xenova/claude-tokenizer")
44
  },
 
 
 
 
 
 
 
 
45
  "GPT_4o": {
46
  "client": OpenAI(api_key=os.environ.get("OPENAI_API_KEY")),
47
  "model": "gpt-4o",
@@ -111,7 +118,7 @@ TOOLS_MODEL = MODEL_CONFIG[modelType].get("tools_model") or MODEL
111
  MAX_CONTEXT = MODEL_CONFIG[modelType]["max_context"]
112
  tokenizer = MODEL_CONFIG[modelType]["tokenizer"]
113
 
114
- isClaudeModel = modelType == "CLAUDE"
115
  isDeepSeekModel = modelType.startswith("DEEPSEEK")
116
 
117
 
@@ -211,7 +218,7 @@ def get_table_schema(table_name):
211
 
212
  def get_sample_data(table_name):
213
  if not st.session_state.engine:
214
- return None
215
 
216
  query = f"SELECT * FROM {table_name} ORDER BY 1 DESC LIMIT 3"
217
  try:
@@ -219,8 +226,8 @@ def get_sample_data(table_name):
219
  df = pd.read_sql(query, conn)
220
  return df
221
  except Exception as e:
222
- st.error(f"Error fetching sample data: {str(e)}")
223
- return None
224
 
225
 
226
  def clean_sql_response(response: str) -> str:
@@ -254,28 +261,57 @@ def execute_query(query):
254
 
255
 
256
  def generate_sql_query(user_query):
257
- prompt = f"""You are a SQL expert. Generate a valid PostgreSQL query based on the following context and user query.
258
-
259
- Table Name: {st.session_state.selected_table}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
260
 
261
- Table Schema:
262
- {json.dumps(st.session_state.table_schema, indent=2)}
263
 
264
- Sample Data:
265
- {st.session_state.sample_data.to_markdown(index=False)}
266
 
267
  Important:
268
  1. Only return the SQL query, nothing else
269
  2. The query should be valid PostgreSQL syntax
270
  3. Do not include any explanations or comments
271
  4. Make sure to handle NULL values appropriately
272
- 5. Use the table name '{st.session_state.selected_table}' in your query
 
273
 
274
  User Query: {user_query}
275
  """
276
-
277
  prompt_tokens = __countTokens(prompt)
278
- pprint(f"\n[{MODEL}] Prompt tokens for SQL generation: {prompt_tokens}")
 
279
 
280
  # Debug prompt in a Streamlit expander for better organization
281
  # Check if running locally based on Streamlit's origin header
@@ -340,56 +376,98 @@ if st.session_state.connection_string:
340
  db_objects = [(table, 'Table') for table in tables] + [(view, 'View') for view in views]
341
  db_objects.sort(key=lambda x: x[0]) # Sort alphabetically by name
342
 
343
- # Create two columns for the selection
344
- col1, col2 = st.columns([3, 1])
 
 
 
345
 
346
- with col1:
347
- # Extract just the names for the selectbox
348
- object_names = [obj[0] for obj in db_objects]
349
- # Set default index to 'lsq_leads' if present, otherwise 0
350
- default_index = object_names.index('lsq_leads') if 'lsq_leads' in object_names else 0
351
- selected_object = st.selectbox("Select a table or view", object_names, index=default_index)
 
352
 
353
- with col2:
354
- # Display the object type (Table/View)
355
- object_type = next(obj_type for obj_name, obj_type in db_objects if obj_name == selected_object)
356
- st.text_input("Type", value=object_type, disabled=True)
 
 
357
 
358
  # Create containers for schema and data
359
  schema_container = st.container()
360
  data_container = st.container()
361
 
362
- # Always load object data if we have a selection
363
- if selected_object:
364
- # Update session state
365
- if selected_object != st.session_state.selected_table:
366
- st.session_state.selected_table = selected_object
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
367
 
368
- # Always fetch schema and sample data
369
- st.session_state.table_schema = get_table_schema(selected_object)
370
- st.session_state.sample_data = get_sample_data(selected_object)
 
 
 
 
 
 
 
 
 
 
 
 
 
371
 
372
- # Always display schema and sample data if available
373
  with schema_container:
374
- if st.session_state.table_schema:
375
- st.subheader("Table Schema")
376
- # Force immediate rendering with an empty element
377
- st.empty()
378
- st.json(st.session_state.table_schema)
 
 
 
379
 
380
  with data_container:
381
- if st.session_state.sample_data is not None:
382
- st.subheader("Sample Data (Last 3 rows)")
383
- # Force immediate rendering with an empty element
384
- st.empty()
385
- st.dataframe(
386
- st.session_state.sample_data,
387
- use_container_width=True,
388
- hide_index=True
389
- )
 
 
 
390
 
391
  # Query Input Section
392
- if st.session_state.selected_table:
393
  st.header("3. Query Input")
394
  user_query = st.text_area("Enter your query in plain English")
395
 
 
3
  import pandas as pd
4
  from typing import Literal, TypedDict
5
  from sqlalchemy import create_engine, inspect, text
 
6
  from transformers import AutoTokenizer
7
  from utils import pprint
8
  import time
 
33
  })
34
 
35
  MODEL_CONFIG: dict[ModelType, ModelConfig] = {
36
+ "CLAUDE_HAIKU": {
37
  "client": anthropic.Anthropic(api_key=os.environ.get("ANTHROPIC_API_KEY")),
38
  "model": "claude-3-5-haiku-20241022",
39
  # "model": "claude-3-5-sonnet-20241022",
 
41
  "max_context": 40000,
42
  "tokenizer": AutoTokenizer.from_pretrained("Xenova/claude-tokenizer")
43
  },
44
+ "CLAUDE_SONNET": {
45
+ "client": anthropic.Anthropic(api_key=os.environ.get("ANTHROPIC_API_KEY")),
46
+ # "model": "claude-3-5-haiku-20241022",
47
+ # "model": "claude-3-5-sonnet-20241022",
48
+ "model": "claude-3-5-sonnet-20240620",
49
+ "max_context": 40000,
50
+ "tokenizer": AutoTokenizer.from_pretrained("Xenova/claude-tokenizer")
51
+ },
52
  "GPT_4o": {
53
  "client": OpenAI(api_key=os.environ.get("OPENAI_API_KEY")),
54
  "model": "gpt-4o",
 
118
  MAX_CONTEXT = MODEL_CONFIG[modelType]["max_context"]
119
  tokenizer = MODEL_CONFIG[modelType]["tokenizer"]
120
 
121
+ isClaudeModel = modelType.startswith("CLAUDE")
122
  isDeepSeekModel = modelType.startswith("DEEPSEEK")
123
 
124
 
 
218
 
219
  def get_sample_data(table_name):
220
  if not st.session_state.engine:
221
+ return pd.DataFrame() # Return empty DataFrame instead of None
222
 
223
  query = f"SELECT * FROM {table_name} ORDER BY 1 DESC LIMIT 3"
224
  try:
 
226
  df = pd.read_sql(query, conn)
227
  return df
228
  except Exception as e:
229
+ st.error(f"Error fetching sample data for {table_name}: {str(e)}")
230
+ return pd.DataFrame() # Return empty DataFrame on error
231
 
232
 
233
  def clean_sql_response(response: str) -> str:
 
261
 
262
 
263
  def generate_sql_query(user_query):
264
+ # Build context for all selected tables
265
+ tables_context = []
266
+ for table_name, table_type in st.session_state.selected_tables.items():
267
+ # Format schema in markdown
268
+ schema_info = st.session_state.table_schemas[table_name]
269
+
270
+ # Build markdown formatted schema
271
+ schema_md = [f"\n\n### {table_type}: {table_name}"]
272
+
273
+ # Add table comment if exists
274
+ if schema_info.get("table_comment"):
275
+ schema_md.append(f"> {schema_info['table_comment']}")
276
+
277
+ # Add column details
278
+ schema_md.append("\n**Columns:**")
279
+ for col_name, col_info in schema_info["columns"].items():
280
+ col_type = col_info["type"]
281
+ col_comment = col_info.get("comment")
282
+
283
+ # Format column with type and optional comment
284
+ if col_comment:
285
+ schema_md.append(f"- `{col_name}` ({col_type}) - {col_comment}")
286
+ else:
287
+ schema_md.append(f"- `{col_name}` ({col_type})")
288
+
289
+ # Add sample data
290
+ schema_md.append("\n**Sample Data:**")
291
+ schema_md.append(st.session_state.sample_data[table_name].to_markdown(index=False))
292
+
293
+ # Join all parts with newlines
294
+ tables_context.append("\n".join(schema_md))
295
 
296
+ prompt = f"""You are a SQL expert. Generate a valid PostgreSQL query based on the following context and user query.
 
297
 
298
+ <AVAILABLE_OBJECTS>
299
+ {chr(10).join(tables_context)}
300
 
301
  Important:
302
  1. Only return the SQL query, nothing else
303
  2. The query should be valid PostgreSQL syntax
304
  3. Do not include any explanations or comments
305
  4. Make sure to handle NULL values appropriately
306
+ 5. If joining tables, use appropriate join conditions based on the schema
307
+ 6. Use table names with appropriate qualifiers to avoid ambiguity
308
 
309
  User Query: {user_query}
310
  """
311
+
312
  prompt_tokens = __countTokens(prompt)
313
+ print("\n")
314
+ pprint(f"[{MODEL}] Prompt tokens for SQL generation: {prompt_tokens}")
315
 
316
  # Debug prompt in a Streamlit expander for better organization
317
  # Check if running locally based on Streamlit's origin header
 
376
  db_objects = [(table, 'Table') for table in tables] + [(view, 'View') for view in views]
377
  db_objects.sort(key=lambda x: x[0]) # Sort alphabetically by name
378
 
379
+ # Extract just the names for the multiselect
380
+ object_names = [obj[0] for obj in db_objects]
381
+
382
+ # Default to 'lsq_leads' if present
383
+ default_selections = ['lsq_leads'] if 'lsq_leads' in object_names else []
384
 
385
+ # Create multiselect for table/view selection
386
+ selected_objects = st.multiselect(
387
+ "Select tables/views",
388
+ options=object_names,
389
+ default=default_selections,
390
+ help="You can select multiple tables/views to query across them"
391
+ )
392
 
393
+ # Display selected object types
394
+ if selected_objects:
395
+ st.write("Selected objects:")
396
+ for obj in selected_objects:
397
+ obj_type = next(obj_type for obj_name, obj_type in db_objects if obj_name == obj)
398
+ st.write(f"- {obj}: {obj_type}")
399
 
400
  # Create containers for schema and data
401
  schema_container = st.container()
402
  data_container = st.container()
403
 
404
+ # Initialize or reset session state for selected objects
405
+ if selected_objects:
406
+ # Always ensure dictionaries exist in session state
407
+ if not isinstance(st.session_state.get("selected_tables"), dict):
408
+ st.session_state.selected_tables = {}
409
+ if not isinstance(st.session_state.get("table_schemas"), dict):
410
+ st.session_state.table_schemas = {}
411
+ if not isinstance(st.session_state.get("sample_data"), dict):
412
+ st.session_state.sample_data = {}
413
+
414
+ # Clear previous data for tables that are no longer selected
415
+ current_tables = set(selected_objects)
416
+ previous_tables = set(st.session_state.selected_tables.keys())
417
+ removed_tables = previous_tables - current_tables
418
+
419
+ for table in removed_tables:
420
+ if table in st.session_state.selected_tables:
421
+ del st.session_state.selected_tables[table]
422
+ if table in st.session_state.table_schemas:
423
+ del st.session_state.table_schemas[table]
424
+ if table in st.session_state.sample_data:
425
+ del st.session_state.sample_data[table]
426
 
427
+ # Update session state with new selections
428
+ for obj in selected_objects:
429
+ # Update selected tables
430
+ st.session_state.selected_tables[obj] = next(
431
+ obj_type for obj_name, obj_type in db_objects if obj_name == obj
432
+ )
433
+
434
+ # Fetch and store schema
435
+ schema = get_table_schema(obj)
436
+ if schema:
437
+ st.session_state.table_schemas[obj] = schema
438
+
439
+ # Fetch and store sample data
440
+ sample_data = get_sample_data(obj)
441
+ if not sample_data.empty:
442
+ st.session_state.sample_data[obj] = sample_data
443
 
444
+ # Display schema and sample data for each selected object
445
  with schema_container:
446
+ st.subheader("Table/View Schemas")
447
+ for obj in selected_objects:
448
+ if obj in st.session_state.table_schemas:
449
+ st.write(f"**{obj} Schema:**")
450
+ st.json(st.session_state.table_schemas[obj])
451
+ st.write("---")
452
+ else:
453
+ st.warning(f"Could not fetch schema for {obj}")
454
 
455
  with data_container:
456
+ st.subheader("Sample Data")
457
+ for obj in selected_objects:
458
+ if obj in st.session_state.sample_data and not st.session_state.sample_data[obj].empty:
459
+ st.write(f"**{obj} (Last 3 rows):**")
460
+ st.dataframe(
461
+ st.session_state.sample_data[obj],
462
+ use_container_width=True,
463
+ hide_index=True
464
+ )
465
+ st.write("---")
466
+ else:
467
+ st.warning(f"No sample data available for {obj}")
468
 
469
  # Query Input Section
470
+ if st.session_state.get("selected_tables"):
471
  st.header("3. Query Input")
472
  user_query = st.text_area("Enter your query in plain English")
473