"""Streamlit app that turns plain-English questions into PostgreSQL queries.

The user picks an LLM in the sidebar, enters a PostgreSQL connection string,
chooses a table, and types a question. The table's schema (with comments) and
a few sample rows are sent to the model together with the question; the SQL
the model returns is shown and then executed, and the result set is displayed.
"""

import json
import os
import re
import time
from typing import Literal, TypedDict

import anthropic
import pandas as pd
import streamlit as st
from dotenv import load_dotenv
from openai import OpenAI
from sqlalchemy import create_engine, inspect, text
from transformers import AutoTokenizer

from clients.openRouter import OpenRouter
from utils import pprint

# Load environment variables (API keys are read from os.environ below).
load_dotenv()

# Set up page configuration
st.set_page_config(
    page_title="SQL Query Assistant",
    page_icon="💾",
    layout="centered",
    initial_sidebar_state="collapsed",
)

ModelType = Literal[
    "GPT_4o", "GPT_o1", "CLAUDE", "LLAMA", "DEEPSEEK", "DEEPSEEK_R1", "DEEPSEEK_R1_DISTILL"
]

# Per-model settings. Besides OpenAI/Anthropic, "client" may also be an
# OpenRouter instance, which is driven through the same
# `client.chat.completions.create` call as OpenAI (see generate_sql_query).
# An optional "tools_model" key is read with .get() below but is not declared here.
ModelConfig = TypedDict("ModelConfig", {
    "client": OpenAI | anthropic.Anthropic,
    "model": str,
    "max_context": int,  # prompt-token budget; exceeding it triggers a warning
    "tokenizer": AutoTokenizer,
})

MODEL_CONFIG: dict[ModelType, ModelConfig] = {
    "CLAUDE": {
        "client": anthropic.Anthropic(api_key=os.environ.get("ANTHROPIC_API_KEY")),
        "model": "claude-3-5-haiku-20241022",
        # "model": "claude-3-5-sonnet-20241022",
        # "model": "claude-3-5-sonnet-20240620",
        "max_context": 40000,
        "tokenizer": AutoTokenizer.from_pretrained("Xenova/claude-tokenizer"),
    },
    "GPT_4o": {
        "client": OpenAI(api_key=os.environ.get("OPENAI_API_KEY")),
        "model": "gpt-4o",
        "max_context": 15000,
        "tokenizer": AutoTokenizer.from_pretrained("Xenova/gpt-4o"),
    },
    # "GPT_o1": {
    #     "client": OpenAI(api_key=os.environ.get("OPENAI_API_KEY")),
    #     "model": "o1-preview",
    #     "max_context": 15000,
    #     "tokenizer": AutoTokenizer.from_pretrained("Xenova/gpt-4o")
    # },
    "DEEPSEEK": {
        "client": OpenRouter(
            api_key=os.environ.get("OPENROUTER_API_KEY"),
        ),
        "model": "deepseek/deepseek-chat",
        "max_context": 30000,
        "tokenizer": AutoTokenizer.from_pretrained("Xenova/gpt-4o"),
    },
    "DEEPSEEK_R1": {
        "client": OpenRouter(
            api_key=os.environ.get("OPENROUTER_API_KEY"),
        ),
        "model": "deepseek/deepseek-r1",
        "max_context": 30000,
        "tokenizer": AutoTokenizer.from_pretrained("Xenova/gpt-4o"),
    },
}


def get_model_type() -> ModelType:
    """Let the user pick a model in the sidebar and return its MODEL_CONFIG key.

    The dropdown lists the configured keys, labelled with their provider model
    names, and defaults to the first entry.
    """
    available_models = list(MODEL_CONFIG.keys())
    # format_func shows the provider model name while the widget returns the key
    # itself, so no reverse lookup by display label is needed.
    return st.sidebar.selectbox(
        "Select AI Model",
        available_models,
        index=0,
        format_func=lambda model_type: MODEL_CONFIG[model_type]["model"],
    )


# Resolve the selected model's settings into module-level globals used below.
modelType = get_model_type()
client = MODEL_CONFIG[modelType]["client"]
MODEL = MODEL_CONFIG[modelType]["model"]
TOOLS_MODEL = MODEL_CONFIG[modelType].get("tools_model") or MODEL
MAX_CONTEXT = MODEL_CONFIG[modelType]["max_context"]
tokenizer = MODEL_CONFIG[modelType]["tokenizer"]
isClaudeModel = modelType == "CLAUDE"
isDeepSeekModel = modelType.startswith("DEEPSEEK")


def __countTokens(text):
    """Return how many tokens the selected model's tokenizer produces for ``str(text)``."""
    text = str(text)
    tokens = tokenizer.encode(text, add_special_tokens=False)
    return len(tokens)


# Initialize session state variables
if "ipAddress" not in st.session_state:
    st.session_state.ipAddress = st.context.headers.get("x-forwarded-for")
for _state_key in ("connection_string", "selected_table", "table_schema", "sample_data", "engine"):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = None


def connect_to_db(connection_string):
    """Create an engine for ``connection_string`` and verify it with a test connection.

    On success the engine is stored in ``st.session_state.engine``, any engine
    it replaces is disposed, and ``True`` is returned. On failure an error is
    shown, the new engine (if it was created) is disposed, and ``False`` is
    returned; the previously stored engine is left untouched.
    """
    engine = None
    try:
        engine = create_engine(connection_string)
        # Test the connection
        with engine.connect():
            pass
    except Exception as e:
        if engine is not None:
            engine.dispose()  # release the pool of the engine that could not connect
        st.error(f"Failed to connect to database: {str(e)}")
        return False

    previous_engine = st.session_state.engine
    st.session_state.engine = engine
    if previous_engine is not None and previous_engine is not engine:
        previous_engine.dispose()  # don't keep the old connection pool alive
    return True


def get_table_schema(table_name):
    """Return the columns of ``table_name`` together with table/column comments.

    Returns ``None`` when no database is connected. On success the result is::

        {"table_comment": str | None,
         "columns": {column_name: {"type": str, "comment": str | None}}}

    Comments are looked up in the ``public`` schema. If reading them fails, an
    error is shown and the plain ``{column_name: type_string}`` mapping is
    returned instead.
    """
    engine = st.session_state.engine
    if not engine:
        return None

    inspector = inspect(engine)
    columns = inspector.get_columns(table_name)
    schema = {col["name"]: str(col["type"]) for col in columns}

    # Get table comment
    table_comment_query = """
        SELECT obj_description(c.oid) as table_comment
        FROM pg_class c
        JOIN pg_namespace n ON n.oid = c.relnamespace
        WHERE c.relname = :table_name
          AND n.nspname = 'public'
    """

    # Get column comments. quote_ident() builds a correctly quoted,
    # schema-qualified name for the regclass cast, so table names containing
    # quotes or special characters work and the lookup always targets
    # public.<table> regardless of the session's search_path.
    column_comments_query = """
        SELECT
            cols.column_name,
            pg_catalog.col_description(
                (quote_ident(cols.table_schema) || '.' || quote_ident(cols.table_name))::regclass::oid,
                cols.ordinal_position::int
            ) AS column_comment
        FROM information_schema.columns cols
        WHERE cols.table_name = :table_name
          AND cols.table_schema = 'public'
    """

    try:
        with engine.connect() as conn:
            table_comment_result = conn.execute(
                text(table_comment_query), {"table_name": table_name}
            ).fetchone()
            table_comment = table_comment_result[0] if table_comment_result else None

            column_comments_result = conn.execute(
                text(column_comments_query), {"table_name": table_name}
            ).fetchall()
            column_comments = {row[0]: row[1] for row in column_comments_result}

        return {
            "table_comment": table_comment,
            "columns": {
                col_name: {
                    "type": col_type,
                    "comment": column_comments.get(col_name),
                }
                for col_name, col_type in schema.items()
            },
        }
    except Exception as e:
        st.error(f"Error fetching schema comments: {str(e)}")
        return schema  # Fallback to basic schema if comment retrieval fails


def get_sample_data(table_name):
    """Return up to 3 rows of ``table_name`` as a DataFrame, ordered by its first column descending.

    Returns ``None`` when no database is connected or the query fails (an error
    is shown in the latter case).
    """
    engine = st.session_state.engine
    if not engine:
        return None
    # Quote the identifier with the dialect's rules: names that need quoting
    # (mixed case, reserved words, special characters) are wrapped, plain
    # lowercase names are emitted unchanged.
    quoted_table = engine.dialect.identifier_preparer.quote(table_name)
    query = f"SELECT * FROM {quoted_table} ORDER BY 1 DESC LIMIT 3"
    try:
        with engine.connect() as conn:
            return pd.read_sql(query, conn)
    except Exception as e:
        st.error(f"Error fetching sample data: {str(e)}")
        return None


# A fenced block explicitly tagged as SQL anywhere in the response...
_SQL_FENCE_RE = re.compile(
    r"```[ \t]*(?:sql|postgresql|postgres|pgsql)[ \t]*\r?\n(.*?)```",
    re.IGNORECASE | re.DOTALL,
)
# ...or, failing that, a response that is itself one untagged fenced block.
_BARE_FENCE_RE = re.compile(r"\A\s*```[ \t]*\r?\n(.*?)```", re.DOTALL)


def clean_sql_response(response: str | None) -> str:
    """Extract a clean SQL query from a possibly markdown-formatted model response.

    Prefers the body of the first SQL-tagged fenced block (tag matched
    case-insensitively; LF or CRLF line endings), then a response that is a
    single untagged fenced block, and otherwise returns the stripped response.
    A ``None``/empty response yields ``""``.
    """
    if not response:
        return ""
    match = _SQL_FENCE_RE.search(response) or _BARE_FENCE_RE.search(response)
    if match:
        return match.group(1).strip()
    return response.strip()


def execute_query(query):
    """Run ``query`` on the connected database and return the result as a DataFrame.

    Returns ``None`` when no database is connected or execution fails (an error
    is shown in the latter case). Execution latency is logged via ``pprint``.
    """
    if not st.session_state.engine:
        return None
    # NOTE(review): this executes model-generated SQL verbatim with the
    # connection's full privileges. The connection is never committed here;
    # under SQLAlchemy 2.x uncommitted work is rolled back when the connection
    # closes, but that is not a substitute for a read-only database role —
    # TODO confirm the SQLAlchemy version and consider restricting privileges.
    try:
        start_time = time.time()
        with st.spinner("Executing SQL query..."):
            with st.session_state.engine.connect() as conn:
                df = pd.read_sql(query, conn)
        execution_time = time.time() - start_time
        pprint(f"[Query Execution] Latency: {execution_time:.2f}s")
        return df
    except Exception as e:
        st.error(f"Error executing query: {str(e)}")
        return None


def generate_sql_query(user_query):
    """Ask the selected model for a PostgreSQL query answering ``user_query``.

    The prompt embeds the selected table's name, its schema (as JSON) and the
    sample rows (as markdown, or a placeholder when they are unavailable). A
    warning is shown when the prompt exceeds the model's configured
    ``max_context``. Returns the SQL extracted by ``clean_sql_response``
    (``""`` if the model returned no text).
    """
    sample_data = st.session_state.sample_data
    # get_sample_data returns None on failure; don't crash on .to_markdown().
    sample_data_text = (
        sample_data.to_markdown(index=False)
        if sample_data is not None
        else "(sample data unavailable)"
    )

    prompt = f"""You are a SQL expert. Generate a valid PostgreSQL query based on the following context and user query.

Table Name: {st.session_state.selected_table}

Table Schema:
{json.dumps(st.session_state.table_schema, indent=2)}

Sample Data:
{sample_data_text}

Important:
1. Only return the SQL query, nothing else
2. The query should be valid PostgreSQL syntax
3. Do not include any explanations or comments
4. Make sure to handle NULL values appropriately
5. Use the table name '{st.session_state.selected_table}' in your query

User Query: {user_query}
"""

    prompt_tokens = __countTokens(prompt)
    pprint(f"\n[{MODEL}] Prompt tokens for SQL generation: {prompt_tokens}")
    if prompt_tokens > MAX_CONTEXT:
        st.warning(
            f"The prompt is {prompt_tokens} tokens, which exceeds the configured "
            f"limit of {MAX_CONTEXT} for {MODEL}; the request may fail."
        )

    # Debug prompt in a Streamlit expander for better organization.
    # Check if running locally based on the request's Origin header.
    if "localhost" in st.context.headers.get("Origin", ""):
        with st.expander("Debug: Prompt Generation"):
            st.write(f"\nUser Query: {user_query}")
            st.write("\nFull Prompt:")
            st.code(prompt, language="text")

    start_time = time.time()
    with st.spinner(f"Generating SQL query using {MODEL}..."):
        if isClaudeModel:
            response = client.messages.create(
                model=MODEL,
                max_tokens=1000,
                messages=[
                    {"role": "user", "content": prompt},
                ],
            )
            raw_response = response.content[0].text
        else:
            # OpenAI and OpenRouter clients share the chat.completions interface.
            response = client.chat.completions.create(
                model=MODEL,
                messages=[
                    {"role": "user", "content": prompt},
                ],
            )
            raw_response = response.choices[0].message.content
    generation_time = time.time() - start_time
    pprint(f"[{MODEL}] Query Generation Latency: {generation_time:.2f}s")

    return clean_sql_response(raw_response)


# UI Components
st.title("SQL Query Assistant")

# Database Connection Section
st.header("1. Database Connection")
connection_string = st.text_input(
    "Enter PostgreSQL Connection String",
    value=st.session_state.connection_string if st.session_state.connection_string else "",
    type="password",
)

if connection_string and connection_string != st.session_state.connection_string:
    if connect_to_db(connection_string):
        st.session_state.connection_string = connection_string
        st.success("Successfully connected to database!")

# Table Selection Section
if st.session_state.connection_string:
    st.header("2. Table Selection")
    inspector = inspect(st.session_state.engine)
    tables = inspector.get_table_names()
    # Set default index to 'lsq_leads' if present, otherwise 0
    default_index = tables.index("lsq_leads") if "lsq_leads" in tables else 0
    selected_table = st.selectbox("Select a table", tables, index=default_index)

    # Create containers for schema and data
    schema_container = st.container()
    data_container = st.container()

    # Always load table data if we have a selected table
    if selected_table:
        if selected_table != st.session_state.selected_table:
            st.session_state.selected_table = selected_table

        # Re-fetched on every rerun so the displayed schema/rows stay current.
        st.session_state.table_schema = get_table_schema(selected_table)
        st.session_state.sample_data = get_sample_data(selected_table)

        with schema_container:
            if st.session_state.table_schema:
                st.subheader("Table Schema")
                # Force immediate rendering with an empty element
                st.empty()
                st.json(st.session_state.table_schema)

        with data_container:
            if st.session_state.sample_data is not None:
                st.subheader("Sample Data (Last 3 rows)")
                # Force immediate rendering with an empty element
                st.empty()
                st.dataframe(
                    st.session_state.sample_data,
                    use_container_width=True,
                    hide_index=True,
                )

# Query Input Section
if st.session_state.selected_table:
    st.header("3. Query Input")
    user_query = st.text_area("Enter your query in plain English")

    if st.button("Generate and Execute Query"):
        if not user_query.strip():
            st.warning("Please enter a query first.")
        else:
            sql_query = generate_sql_query(user_query)

            st.subheader("Generated SQL Query")
            st.code(sql_query, language="sql")

            if not sql_query:
                st.error("The model returned an empty response; nothing to execute.")
            else:
                results = execute_query(sql_query)
                if results is not None:
                    st.subheader("Query Results")
                    st.dataframe(results)