"""Streamlit app that turns plain-English questions into read-only SQL queries.

Flow: the user supplies a PostgreSQL connection string, selects tables/views,
and types a question. The schema, column comments and a few sample rows of the
selected objects are sent to the chosen LLM, which returns a SQL query. The
query is validated as read-only and executed, and the result is displayed.
"""

import os
import re
import time
from typing import Literal, TypedDict

import anthropic
import pandas as pd
import streamlit as st
from dotenv import load_dotenv
from openai import OpenAI
from sqlalchemy import create_engine, inspect, text
from transformers import AutoTokenizer

from clients.openRouter import OpenRouter
from utils import pprint

# Load environment variables
load_dotenv()

# Set up page configuration (must run before any other Streamlit command)
st.set_page_config(
    page_title="SQL Query Assistant",
    page_icon="💾",
    layout="centered",
    initial_sidebar_state="collapsed",
)

# Includes every key used in MODEL_CONFIG; the remaining names are kept so
# annotations elsewhere that still mention them stay valid.
ModelType = Literal[
    "CLAUDE_HAIKU",
    "CLAUDE_SONNET",
    "GPT_4o",
    "GPT_o1",
    "CLAUDE",
    "LLAMA",
    "DEEPSEEK",
    "DEEPSEEK_R1",
    "DEEPSEEK_R1_DISTILL",
]

ModelConfig = TypedDict("ModelConfig", {
    "client": OpenAI | anthropic.Anthropic,
    "model": str,
    "max_context": int,
    "tokenizer": AutoTokenizer,
})


@st.cache_resource(show_spinner=False)
def _build_model_config() -> dict[ModelType, ModelConfig]:
    """Create the API clients and tokenizers once per server process.

    Streamlit re-executes this script on every interaction; without caching,
    each rerun would rebuild every client and reload every tokenizer. Each
    distinct client/tokenizer is also created only once and shared between
    the entries that use it. API keys are read from the environment when the
    cache is first filled, so changing them requires a restart.
    """
    anthropic_client = anthropic.Anthropic(api_key=os.environ.get("ANTHROPIC_API_KEY"))
    openai_client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
    openrouter_client = OpenRouter(api_key=os.environ.get("OPENROUTER_API_KEY"))

    claude_tokenizer = AutoTokenizer.from_pretrained("Xenova/claude-tokenizer")
    gpt4o_tokenizer = AutoTokenizer.from_pretrained("Xenova/gpt-4o")

    return {
        "CLAUDE_HAIKU": {
            "client": anthropic_client,
            "model": "claude-3-5-haiku-20241022",
            "max_context": 40000,
            "tokenizer": claude_tokenizer,
        },
        "CLAUDE_SONNET": {
            "client": anthropic_client,
            "model": "claude-3-5-sonnet-20240620",
            "max_context": 40000,
            "tokenizer": claude_tokenizer,
        },
        "GPT_4o": {
            "client": openai_client,
            "model": "gpt-4o",
            "max_context": 15000,
            "tokenizer": gpt4o_tokenizer,
        },
        # "GPT_o1": {
        #     "client": openai_client,
        #     "model": "o1-preview",
        #     "max_context": 15000,
        #     "tokenizer": gpt4o_tokenizer,
        # },
        "DEEPSEEK": {
            "client": openrouter_client,
            "model": "deepseek/deepseek-chat",
            "max_context": 30000,
            # The GPT-4o tokenizer is used for token counting here.
            "tokenizer": gpt4o_tokenizer,
        },
        "DEEPSEEK_R1": {
            "client": openrouter_client,
            "model": "deepseek/deepseek-r1",
            "max_context": 30000,
            "tokenizer": gpt4o_tokenizer,
        },
    }


MODEL_CONFIG: dict[ModelType, ModelConfig] = _build_model_config()


def get_model_type():
    """Render the sidebar model picker and return the selected MODEL_CONFIG key.

    The select box lists the keys but displays each entry's ``model`` name via
    ``format_func``, so the key comes back directly and no reverse lookup by
    display name is needed.
    """
    return st.sidebar.selectbox(
        "Select AI Model",
        list(MODEL_CONFIG),
        index=0,
        format_func=lambda model_type: MODEL_CONFIG[model_type]["model"],
    )


# Resolve the active model and the settings derived from it
modelType = get_model_type()
client = MODEL_CONFIG[modelType]["client"]
MODEL = MODEL_CONFIG[modelType]["model"]
TOOLS_MODEL = MODEL_CONFIG[modelType].get("tools_model") or MODEL
MAX_CONTEXT = MODEL_CONFIG[modelType]["max_context"]
tokenizer = MODEL_CONFIG[modelType]["tokenizer"]
isClaudeModel = modelType.startswith("CLAUDE")
isDeepSeekModel = modelType.startswith("DEEPSEEK")


def __countTokens(text):
    """Return the number of tokens the active model's tokenizer yields for ``text``.

    ``text`` is converted with ``str()`` first; special tokens are not added.
    """
    return len(tokenizer.encode(str(text), add_special_tokens=False))


# Initialize session state variables
if "ipAddress" not in st.session_state:
    st.session_state.ipAddress = st.context.headers.get("x-forwarded-for")
for _state_key in (
    "connection_string",
    "selected_table",
    "table_schema",
    "sample_data",
    "engine",
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = None


def connect_to_db(connection_string):
    """Create an engine for ``connection_string`` and store it in session state.

    The connection is tested by opening and closing it once. On success the
    new engine replaces ``st.session_state.engine`` and the previous engine,
    if any, is disposed so its pooled connections are released.

    Returns:
        True when the connection succeeded, False otherwise (an error is
        shown in the UI and the previously stored engine is left untouched).
    """
    engine = None
    try:
        engine = create_engine(connection_string)
        # Test the connection
        with engine.connect():
            pass
    except Exception as e:
        if engine is not None:
            engine.dispose()
        st.error(f"Failed to connect to database: {str(e)}")
        return False

    previous_engine = st.session_state.engine
    st.session_state.engine = engine
    if previous_engine is not None:
        previous_engine.dispose()
    return True


# PostgreSQL catalog queries; both only look at the ``public`` schema.
_TABLE_COMMENT_QUERY = """
    SELECT obj_description(c.oid) as table_comment
    FROM pg_class c
    JOIN pg_namespace n ON n.oid = c.relnamespace
    WHERE c.relname = :table_name
    AND n.nspname = 'public'
"""

_COLUMN_COMMENTS_QUERY = """
    SELECT
        cols.column_name,
        (
            SELECT pg_catalog.col_description(c.oid, cols.ordinal_position::int)
            FROM pg_catalog.pg_class c
            WHERE c.oid = (SELECT ('"' || cols.table_name || '"')::regclass::oid)
            AND c.relname = cols.table_name
        ) as column_comment
    FROM information_schema.columns cols
    WHERE cols.table_name = :table_name
    AND cols.table_schema = 'public'
"""


def get_table_schema(table_name):
    """Return column types and comments for ``table_name``.

    Returns:
        None when no engine is connected; otherwise a dict of the form
        ``{"table_comment": str | None,
           "columns": {name: {"type": str, "comment": str | None}}}``.

        The same shape is returned when the comment queries fail (an error is
        shown and the comments that could not be read are None), because
        ``generate_sql_query`` indexes ``["columns"]`` on the stored schema.
    """
    engine = st.session_state.engine
    if not engine:
        return None

    columns = inspect(engine).get_columns(table_name)
    column_types = {col["name"]: str(col["type"]) for col in columns}

    table_comment = None
    column_comments = {}
    params = {"table_name": table_name}
    try:
        with engine.connect() as conn:
            # Get table comment
            table_row = conn.execute(text(_TABLE_COMMENT_QUERY), params).fetchone()
            table_comment = table_row[0] if table_row else None

            # Get column comments
            comment_rows = conn.execute(text(_COLUMN_COMMENTS_QUERY), params).fetchall()
            column_comments = {row[0]: row[1] for row in comment_rows}
    except Exception as e:
        # Keep going with whatever was retrieved; comments are optional.
        st.error(f"Error fetching schema comments: {str(e)}")

    return {
        "table_comment": table_comment,
        "columns": {
            col_name: {
                "type": col_type,
                "comment": column_comments.get(col_name),
            }
            for col_name, col_type in column_types.items()
        },
    }


def get_sample_data(table_name):
    """Return up to three rows of ``table_name``, ordered by the first column descending.

    ``table_name`` is expected to be a bare (not schema-qualified) name as
    returned by the SQLAlchemy inspector. It is quoted with the dialect's
    identifier preparer so mixed-case or reserved-word names work.

    Returns:
        A DataFrame; empty when no engine is connected or the query fails
        (in the latter case an error is shown in the UI).
    """
    engine = st.session_state.engine
    if not engine:
        return pd.DataFrame()  # Return empty DataFrame instead of None

    try:
        quoted_name = engine.dialect.identifier_preparer.quote(table_name)
        query = f"SELECT * FROM {quoted_name} ORDER BY 1 DESC LIMIT 3"
        with engine.connect() as conn:
            return pd.read_sql(query, conn)
    except Exception as e:
        st.error(f"Error fetching sample data for {table_name}: {str(e)}")
        return pd.DataFrame()  # Return empty DataFrame on error


# A fenced block explicitly tagged as SQL (any case, optional CRLF).
_SQL_FENCE_RE = re.compile(
    r"```[ \t]*(?:sql|postgresql|postgres|pgsql)[ \t]*\r?\n(.*?)```",
    re.IGNORECASE | re.DOTALL,
)
# Any fenced block, used only when no SQL-tagged block exists.
_ANY_FENCE_RE = re.compile(r"```[^\n`]*\n(.*?)```", re.DOTALL)


def clean_sql_response(response: str) -> str:
    """Extract a clean SQL query from a potentially formatted response.

    Preference order: the first code fence tagged as SQL, then the first code
    fence of any kind, then the whole response. The result is stripped of
    surrounding whitespace. A missing (None/empty) response yields "".
    """
    if not response:
        return ""
    for fence_pattern in (_SQL_FENCE_RE, _ANY_FENCE_RE):
        fenced = fence_pattern.search(response)
        if fenced:
            return fenced.group(1).strip()
    return response.strip()


_READ_ONLY_LEADERS = frozenset({"SELECT", "WITH"})
_FIRST_KEYWORD_RE = re.compile(r"[\s(]*([A-Za-z_]+)")
_LINE_END_RE = re.compile(r"[\r\n]")
# Keywords that must not appear anywhere in a read-only statement. INTO covers
# ``SELECT ... INTO new_table``. REPLACE/UPSERT are not listed because they
# can only start a statement (already rejected by the allowlist) and
# ``replace(...)`` is a common string function.
_WRITE_KEYWORD_RE = re.compile(
    r"\b(?:INSERT|UPDATE|DELETE|MERGE|DROP|CREATE|ALTER|TRUNCATE|GRANT|REVOKE|INTO)\b",
    re.IGNORECASE,
)
_QUOTED_TEXT_RE = re.compile(r"'(?:[^']|'')*'|\"(?:[^\"]|\"\")*\"")
# Constructs that change how quotes are interpreted (escape strings, dollar
# quoting, comments). When any is present, quoted text is NOT masked.
_AMBIGUOUS_LEXEMES = ("\\", "$", "--", "/*")


def _skip_leading_comments(sql: str) -> str:
    """Return ``sql`` without leading whitespace and comments.

    Handles ``--`` line comments (ended by CR or LF) and nested ``/* */``
    block comments. Returns "" when a leading comment is unterminated or
    nothing follows it.
    """
    pos, end = 0, len(sql)
    while pos < end:
        if sql[pos].isspace():
            pos += 1
        elif sql.startswith("--", pos):
            line_end = _LINE_END_RE.search(sql, pos)
            if line_end is None:
                return ""
            pos = line_end.end()
        elif sql.startswith("/*", pos):
            depth = 0
            while pos < end:
                if sql.startswith("/*", pos):
                    depth += 1
                    pos += 2
                elif sql.startswith("*/", pos):
                    depth -= 1
                    pos += 2
                    if depth == 0:
                        break
                else:
                    pos += 1
            if depth:
                return ""
        else:
            break
    return sql[pos:]


def _mask_quoted_text(sql: str) -> str:
    """Blank out string literals and quoted identifiers when that is unambiguous.

    Masking lets ``'; '`` or ``'delete'`` inside a literal pass the checks.
    If the text contains anything that could change where a literal ends
    (see ``_AMBIGUOUS_LEXEMES``), the text is returned unchanged so the
    caller's checks run on the raw text and fail closed.
    """
    if any(lexeme in sql for lexeme in _AMBIGUOUS_LEXEMES):
        return sql
    return _QUOTED_TEXT_RE.sub("''", sql)


def is_read_only_query(query: str) -> bool:
    """Check if the query is read-only (SELECT only).

    A query is accepted only when all of the following hold:

    * after leading comments and opening parentheses, its first keyword is
      ``SELECT`` or ``WITH`` (an allowlist, so statements such as COPY, CALL,
      DO, SET or EXPLAIN are rejected without having to enumerate them);
    * it is a single statement (no ``;`` other than trailing ones);
    * it contains no data-modifying keyword, which rejects data-modifying
      CTEs such as ``WITH x AS (DELETE ...) SELECT ...``.

    The check is deliberately conservative and may reject unusual but
    harmless queries. It is a first line of defence only.
    """
    statement = _skip_leading_comments(query)

    first_keyword = _FIRST_KEYWORD_RE.match(statement)
    if first_keyword is None or first_keyword.group(1).upper() not in _READ_ONLY_LEADERS:
        return False

    visible = _mask_quoted_text(statement).rstrip("; \t\r\n")
    if ";" in visible:
        return False
    return _WRITE_KEYWORD_RE.search(visible) is None


def execute_query(query):
    """Run ``query`` against the connected database and return a DataFrame.

    The query must pass ``is_read_only_query``. It is executed inside a
    transaction that is always rolled back, and on PostgreSQL the transaction
    is additionally marked ``READ ONLY`` so the server itself rejects writes.

    Returns:
        The result as a DataFrame, or None when no engine is connected, the
        query is rejected, or execution fails (an error is shown in the UI).
    """
    engine = st.session_state.engine
    if not engine:
        return None

    # Check if the query is read-only
    if not is_read_only_query(query):
        st.error("Error: Only SELECT queries are allowed for security reasons.")
        return None

    try:
        start_time = time.time()
        with st.spinner("Executing SQL query..."):
            with engine.connect() as conn:
                transaction = conn.begin()
                try:
                    if engine.dialect.name == "postgresql":
                        # Must be the first statement of the transaction.
                        conn.execute(text("SET TRANSACTION READ ONLY"))
                    result = conn.execute(text(query))
                    df = pd.DataFrame(result.fetchall(), columns=result.keys())
                finally:
                    # A read-only query has nothing to commit.
                    transaction.rollback()

        execution_time = time.time() - start_time
        pprint(f"[Query Execution] Latency: {execution_time:.2f}s")
        return df
    except Exception as e:
        st.error(f"Error executing query: {str(e)}")
        return None


def _format_table_context(table_name, table_type, schema_info, sample_df):
    """Build the markdown description of one table/view for the LLM prompt.

    ``schema_info`` and ``sample_df`` may be None (or empty) when they could
    not be fetched; the corresponding section then says so instead of failing.
    """
    schema_md = [f"\n\n### {table_type}: {table_name}"]
    schema_info = schema_info or {}

    # Add table comment if exists
    if schema_info.get("table_comment"):
        schema_md.append(f"> {schema_info['table_comment']}")

    # Add column details
    schema_md.append("\n**Columns:**")
    columns = schema_info.get("columns") or {}
    if not columns:
        schema_md.append("_Column information is not available._")
    for col_name, col_info in columns.items():
        col_type = col_info["type"]
        col_comment = col_info.get("comment")
        if col_comment:
            schema_md.append(f"- `{col_name}` ({col_type}) - {col_comment}")
        else:
            schema_md.append(f"- `{col_name}` ({col_type})")

    # Add sample data
    schema_md.append("\n**Sample Data:**")
    if sample_df is None or sample_df.empty:
        schema_md.append("_No sample rows available._")
    else:
        schema_md.append(sample_df.to_markdown(index=False))

    return "\n".join(schema_md)


def generate_sql_query(user_query):
    """Ask the selected model for a PostgreSQL query answering ``user_query``.

    The prompt contains, for every table in
    ``st.session_state.selected_tables``, its schema, comments and sample
    rows as stored in session state. Tables whose schema or sample rows are
    missing from session state are still listed, with a note.

    Returns:
        The SQL text extracted from the model response ("" when the model
        returned no text).
    """
    table_schemas = st.session_state.get("table_schemas") or {}
    sample_data = st.session_state.get("sample_data") or {}

    # Build context for all selected tables
    tables_context = [
        _format_table_context(
            table_name,
            table_type,
            table_schemas.get(table_name),
            sample_data.get(table_name),
        )
        for table_name, table_type in st.session_state.selected_tables.items()
    ]
    tables_section = "\n".join(tables_context)

    prompt = f"""You are a SQL expert. Generate a valid PostgreSQL query based on the following context and user query.

{tables_section}

Important:
1. Only generate SELECT queries - no INSERT, UPDATE, DELETE, or other data modification statements
2. Only return the SQL query, nothing else
3. The query should be valid PostgreSQL syntax
4. Do not include any explanations or comments
5. Make sure to handle NULL values appropriately
6. If joining tables, use appropriate join conditions based on the schema
7. Use table names with appropriate qualifiers to avoid ambiguity

User Query: {user_query}
"""

    prompt_tokens = __countTokens(prompt)
    print("\n")
    pprint(f"[{MODEL}] Prompt tokens for SQL generation: {prompt_tokens}")
    if prompt_tokens > MAX_CONTEXT:
        st.warning(
            f"The prompt is {prompt_tokens} tokens, above the configured "
            f"{MAX_CONTEXT}-token budget for {MODEL}. Consider selecting fewer tables."
        )

    # Debug prompt in a Streamlit expander, only when served from localhost
    # (judged by the request's Origin header).
    if 'localhost' in st.context.headers.get("Origin", ""):
        with st.expander("Debug: Prompt Generation"):
            st.write(f"\nUser Query: {user_query}")
            st.write("\nFull Prompt:")
            st.code(prompt, language="text")

    messages = [{"role": "user", "content": prompt}]
    start_time = time.time()
    with st.spinner(f"Generating SQL query using {MODEL}..."):
        if isClaudeModel:
            response = client.messages.create(
                model=MODEL,
                max_tokens=1000,
                messages=messages,
            )
            raw_response = response.content[0].text
        else:
            response = client.chat.completions.create(
                model=MODEL,
                messages=messages,
            )
            raw_response = response.choices[0].message.content

    generation_time = time.time() - start_time
    pprint(f"[{MODEL}] Query Generation Latency: {generation_time:.2f}s")

    # The message content can be None; treat that as an empty response.
    return clean_sql_response(raw_response or "")


# UI Components
st.title("SQL Query Assistant")

# Database Connection Section
st.header("1. Database Connection")
connection_string = st.text_input(
    "Enter PostgreSQL Connection String",
    value=st.session_state.connection_string or "",
    type="password",
)

if connection_string and connection_string != st.session_state.connection_string:
    if connect_to_db(connection_string):
        st.session_state.connection_string = connection_string
        st.success("Successfully connected to database!")

# Table Selection Section
if st.session_state.connection_string:
    st.header("2. Database Object Selection")
    inspector = inspect(st.session_state.engine)

    # Get both tables and views
    tables = inspector.get_table_names()
    views = inspector.get_view_names()

    # Create a list of tuples with (name, type) for all database objects
    db_objects = [(table, 'Table') for table in tables] + [(view, 'View') for view in views]
    db_objects.sort(key=lambda x: x[0])  # Sort alphabetically by name

    # Extract just the names for the multiselect
    object_names = [obj_name for obj_name, _ in db_objects]

    # Name -> type lookup; the first entry wins if a name appears twice.
    object_types = {}
    for obj_name, obj_type in db_objects:
        object_types.setdefault(obj_name, obj_type)

    # Default to 'lsq_leads' if present
    default_selections = ['lsq_leads'] if 'lsq_leads' in object_types else []

    # Create multiselect for table/view selection
    selected_objects = st.multiselect(
        "Select tables/views",
        options=object_names,
        default=default_selections,
        help="You can select multiple tables/views to query across them",
    )

    # Display selected object types
    if selected_objects:
        st.write("Selected objects:")
        for obj in selected_objects:
            st.write(f"- {obj}: {object_types[obj]}")

    # Create containers for schema and data
    schema_container = st.container()
    data_container = st.container()

    # Always ensure the per-table dictionaries exist in session state
    _table_state_keys = ("selected_tables", "table_schemas", "sample_data")
    for _state_key in _table_state_keys:
        if not isinstance(st.session_state.get(_state_key), dict):
            st.session_state[_state_key] = {}

    # Drop data for tables that are no longer selected. This runs even when
    # nothing is selected, so deselecting everything clears the state instead
    # of leaving the previous selection active for query generation.
    removed_tables = set(st.session_state.selected_tables) - set(selected_objects)
    for table in removed_tables:
        for _state_key in _table_state_keys:
            st.session_state[_state_key].pop(table, None)

    if selected_objects:
        # Update session state with the current selection
        for obj in selected_objects:
            st.session_state.selected_tables[obj] = object_types[obj]

            # Fetch and store schema; drop a stale entry if the fetch fails
            schema = get_table_schema(obj)
            if schema:
                st.session_state.table_schemas[obj] = schema
            else:
                st.session_state.table_schemas.pop(obj, None)

            # Fetch and store sample data; drop a stale entry if there is none
            sample_data = get_sample_data(obj)
            if not sample_data.empty:
                st.session_state.sample_data[obj] = sample_data
            else:
                st.session_state.sample_data.pop(obj, None)

        # Display schema and sample data for each selected object
        with schema_container:
            st.subheader("Table/View Schemas")
            for obj in selected_objects:
                if obj in st.session_state.table_schemas:
                    st.write(f"**{obj} Schema:**")
                    st.json(st.session_state.table_schemas[obj])
                    st.write("---")
                else:
                    st.warning(f"Could not fetch schema for {obj}")

        with data_container:
            st.subheader("Sample Data")
            for obj in selected_objects:
                if obj in st.session_state.sample_data and not st.session_state.sample_data[obj].empty:
                    st.write(f"**{obj} (Last 3 rows):**")
                    st.dataframe(
                        st.session_state.sample_data[obj],
                        use_container_width=True,
                        hide_index=True,
                    )
                    st.write("---")
                else:
                    st.warning(f"No sample data available for {obj}")

# Query Input Section
if st.session_state.get("selected_tables"):
    st.header("3. Query Input")
    user_query = st.text_area("Enter your query in plain English")

    if st.button("Generate and Execute Query"):
        if user_query:
            # Generate SQL query
            sql_query = generate_sql_query(user_query)

            if not sql_query:
                st.error("The model did not return a SQL query.")
            else:
                # Display the generated query
                st.subheader("Generated SQL Query")
                st.code(sql_query, language="sql")

                # Execute the query
                results = execute_query(sql_query)

                if results is not None:
                    st.subheader("Query Results")
                    st.dataframe(results)