# query-sql / app.py
import streamlit as st
import os
import pandas as pd
from typing import Literal, TypedDict
from sqlalchemy import create_engine, inspect, text
import json
from transformers import AutoTokenizer
from utils import pprint
import time
import re
from openai import OpenAI
import anthropic
from clients.openRouter import OpenRouter
# Load environment variables
from dotenv import load_dotenv
load_dotenv()
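# API keys (ANTHROPIC_API_KEY, OPENAI_API_KEY, OPENROUTER_API_KEY) are read from the environment or a local .env file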
# Set up page configuration
st.set_page_config(
page_title="SQL Query Assistant",
page_icon="💾",
layout="centered",
initial_sidebar_state="collapsed"
)
ModelType = Literal["GPT_4o", "GPT_o1", "CLAUDE", "LLAMA", "DEEPSEEK", "DEEPSEEK_R1", "DEEPSEEK_R1_DISTILL"]
ModelConfig = TypedDict("ModelConfig", {
"client": OpenAI | anthropic.Anthropic,
"model": str,
"max_context": int,
"tokenizer": AutoTokenizer
})
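# Per-model configuration: API client, model identifier, approximate prompt-size budget (tokens),
# and a tokenizer used only for counting prompt tokens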
MODEL_CONFIG: dict[ModelType, ModelConfig] = {
"CLAUDE": {
"client": anthropic.Anthropic(api_key=os.environ.get("ANTHROPIC_API_KEY")),
"model": "claude-3-5-haiku-20241022",
# "model": "claude-3-5-sonnet-20241022",
# "model": "claude-3-5-sonnet-20240620",
"max_context": 40000,
"tokenizer": AutoTokenizer.from_pretrained("Xenova/claude-tokenizer")
},
"GPT_4o": {
"client": OpenAI(api_key=os.environ.get("OPENAI_API_KEY")),
"model": "gpt-4o",
"max_context": 15000,
"tokenizer": AutoTokenizer.from_pretrained("Xenova/gpt-4o")
},
# "GPT_o1": {
# "client": OpenAI(api_key=os.environ.get("OPENAI_API_KEY")),
# "model": "o1-preview",
# "max_context": 15000,
# "tokenizer": AutoTokenizer.from_pretrained("Xenova/gpt-4o")
# },
"DEEPSEEK": {
"client": OpenRouter(
api_key=os.environ.get("OPENROUTER_API_KEY"),
),
"model": "deepseek/deepseek-chat",
"max_context": 30000,
"tokenizer": AutoTokenizer.from_pretrained("Xenova/gpt-4o")
},
"DEEPSEEK_R1": {
"client": OpenRouter(
api_key=os.environ.get("OPENROUTER_API_KEY"),
),
"model": "deepseek/deepseek-r1",
"max_context": 30000,
"tokenizer": AutoTokenizer.from_pretrained("Xenova/gpt-4o")
},
}
def get_model_type():
"""
Get the model type from Streamlit sidebar with model names
"""
# Get the available model types from the MODEL_CONFIG keys
available_models = list(MODEL_CONFIG.keys())
# Create a list of display labels with just the model names
model_display_labels = [
MODEL_CONFIG[model_type]['model']
for model_type in available_models
]
# Add a sidebar selection for model name
selected_model_name = st.sidebar.selectbox(
"Select AI Model",
model_display_labels,
index=0
)
# Find the corresponding model type for the selected model name
selected_model_type = next(
model_type for model_type in available_models
if MODEL_CONFIG[model_type]['model'] == selected_model_name
)
return selected_model_type
# Resolve the selected model and its configuration
modelType = get_model_type()
client = MODEL_CONFIG[modelType]["client"]
MODEL = MODEL_CONFIG[modelType]["model"]
TOOLS_MODEL = MODEL_CONFIG[modelType].get("tools_model") or MODEL
MAX_CONTEXT = MODEL_CONFIG[modelType]["max_context"]
tokenizer = MODEL_CONFIG[modelType]["tokenizer"]
isClaudeModel = modelType == "CLAUDE"
isDeepSeekModel = modelType.startswith("DEEPSEEK")
def __countTokens(text):
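    """Count tokens in `text` using the tokenizer configured for the selected model."""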
text = str(text)
tokens = tokenizer.encode(text, add_special_tokens=False)
return len(tokens)
# Initialize session state variables
if "ipAddress" not in st.session_state:
st.session_state.ipAddress = st.context.headers.get("x-forwarded-for")
if "connection_string" not in st.session_state:
st.session_state.connection_string = None
if "selected_table" not in st.session_state:
st.session_state.selected_table = None
if "table_schema" not in st.session_state:
st.session_state.table_schema = None
if "sample_data" not in st.session_state:
st.session_state.sample_data = None
if "engine" not in st.session_state:
st.session_state.engine = None
def connect_to_db(connection_string):
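    """Create a SQLAlchemy engine for the given connection string, verify it with a test connection, and store it in session state."""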
try:
engine = create_engine(connection_string)
# Test the connection
with engine.connect():
pass
st.session_state.engine = engine
return True
except Exception as e:
st.error(f"Failed to connect to database: {str(e)}")
return False
def get_table_schema(table_name):
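    """Return column types for `table_name`, enriched with table and column comments read from the PostgreSQL catalog; fall back to types only if comment lookup fails."""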
if not st.session_state.engine:
return None
inspector = inspect(st.session_state.engine)
columns = inspector.get_columns(table_name)
schema = {col['name']: str(col['type']) for col in columns}
# Get table comment
table_comment_query = """
SELECT obj_description(c.oid) as table_comment
FROM pg_class c
JOIN pg_namespace n ON n.oid = c.relnamespace
WHERE c.relname = :table_name
AND n.nspname = 'public'
"""
# Get column comments
column_comments_query = """
SELECT
cols.column_name,
(
SELECT pg_catalog.col_description(c.oid, cols.ordinal_position::int)
FROM pg_catalog.pg_class c
WHERE c.oid = (SELECT ('"' || cols.table_name || '"')::regclass::oid)
AND c.relname = cols.table_name
) as column_comment
FROM information_schema.columns cols
WHERE cols.table_name = :table_name
AND cols.table_schema = 'public'
"""
try:
with st.session_state.engine.connect() as conn:
# Get table comment
table_comment_result = conn.execute(text(table_comment_query), {"table_name": table_name}).fetchone()
table_comment = table_comment_result[0] if table_comment_result else None
# Get column comments
column_comments_result = conn.execute(text(column_comments_query), {"table_name": table_name}).fetchall()
column_comments = {row[0]: row[1] for row in column_comments_result}
# Create enhanced schema dictionary
enhanced_schema = {
"table_comment": table_comment,
"columns": {
col_name: {
"type": schema[col_name],
"comment": column_comments.get(col_name)
}
for col_name in schema
}
}
return enhanced_schema
except Exception as e:
st.error(f"Error fetching schema comments: {str(e)}")
return schema # Fallback to basic schema if comment retrieval fails
def get_sample_data(table_name):
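    """Fetch the last three rows of `table_name` (ordered by the first column, descending) as a DataFrame."""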
if not st.session_state.engine:
return None
query = f"SELECT * FROM {table_name} ORDER BY 1 DESC LIMIT 3"
try:
with st.session_state.engine.connect() as conn:
df = pd.read_sql(query, conn)
return df
except Exception as e:
st.error(f"Error fetching sample data: {str(e)}")
return None
def clean_sql_response(response: str) -> str:
"""Extract clean SQL query from a potentially formatted response."""
# If response contains SQL code block, extract it
sql_block_match = re.search(r'```sql\n(.*?)\n```', response, re.DOTALL)
if sql_block_match:
return sql_block_match.group(1).strip()
return response.strip()
def execute_query(query):
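    """Run a SQL query against the connected database and return the result as a DataFrame, logging the execution latency."""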
if not st.session_state.engine:
return None
try:
start_time = time.time()
with st.spinner("Executing SQL query..."):
with st.session_state.engine.connect() as conn:
df = pd.read_sql(query, conn)
execution_time = time.time() - start_time
pprint(f"[Query Execution] Latency: {execution_time:.2f}s")
return df
except Exception as e:
st.error(f"Error executing query: {str(e)}")
return None
def generate_sql_query(user_query):
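    """Build a prompt from the selected table's schema and sample data, ask the chosen model for a PostgreSQL query, and return the cleaned SQL."""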
prompt = f"""You are a SQL expert. Generate a valid PostgreSQL query based on the following context and user query.
Table Name: {st.session_state.selected_table}
Table Schema:
{json.dumps(st.session_state.table_schema, indent=2)}
Sample Data:
{st.session_state.sample_data.to_markdown(index=False)}
Important:
1. Only return the SQL query, nothing else
2. The query should be valid PostgreSQL syntax
3. Do not include any explanations or comments
4. Make sure to handle NULL values appropriately
5. Use the table name '{st.session_state.selected_table}' in your query
User Query: {user_query}
"""
prompt_tokens = __countTokens(prompt)
pprint(f"\n[{MODEL}] Prompt tokens for SQL generation: {prompt_tokens}")
# Debug prompt in a Streamlit expander for better organization
# Check if running locally based on Streamlit's origin header
if 'localhost' in st.context.headers.get("Origin", ""):
with st.expander("Debug: Prompt Generation"):
st.write(f"\nUser Query: {user_query}")
st.write("\nFull Prompt:")
st.code(prompt, language="text")
start_time = time.time()
with st.spinner(f"Generating SQL query using {MODEL}..."):
if isClaudeModel:
response = client.messages.create(
model=MODEL,
max_tokens=1000,
messages=[
{"role": "user", "content": prompt},
]
)
raw_response = response.content[0].text
else:
response = client.chat.completions.create(
model=MODEL,
messages=[
{"role": "user", "content": prompt},
]
)
raw_response = response.choices[0].message.content
generation_time = time.time() - start_time
pprint(f"[{MODEL}] Query Generation Latency: {generation_time:.2f}s")
return clean_sql_response(raw_response)
# UI Components
st.title("SQL Query Assistant")
# Database Connection Section
st.header("1. Database Connection")
connection_string = st.text_input(
"Enter PostgreSQL Connection String",
value=st.session_state.connection_string if st.session_state.connection_string else "",
type="password"
)
if connection_string and connection_string != st.session_state.connection_string:
if connect_to_db(connection_string):
st.session_state.connection_string = connection_string
st.success("Successfully connected to database!")
# Table Selection Section
if st.session_state.connection_string:
st.header("2. Table Selection")
inspector = inspect(st.session_state.engine)
tables = inspector.get_table_names()
# Set default index to 'lsq_leads' if present, otherwise 0
default_index = tables.index('lsq_leads') if 'lsq_leads' in tables else 0
selected_table = st.selectbox("Select a table", tables, index=default_index)
# Create containers for schema and data
schema_container = st.container()
data_container = st.container()
# Always load table data if we have a selected table
if selected_table:
# Update session state
if selected_table != st.session_state.selected_table:
st.session_state.selected_table = selected_table
# Always fetch schema and sample data
st.session_state.table_schema = get_table_schema(selected_table)
st.session_state.sample_data = get_sample_data(selected_table)
# Always display schema and sample data if available
with schema_container:
if st.session_state.table_schema:
st.subheader("Table Schema")
# Force immediate rendering with an empty element
st.empty()
st.json(st.session_state.table_schema)
with data_container:
if st.session_state.sample_data is not None:
st.subheader("Sample Data (Last 3 rows)")
# Force immediate rendering with an empty element
st.empty()
st.dataframe(
st.session_state.sample_data,
use_container_width=True,
hide_index=True
)
# Query Input Section
if st.session_state.selected_table:
st.header("3. Query Input")
user_query = st.text_area("Enter your query in plain English")
if st.button("Generate and Execute Query"):
if user_query:
# Generate SQL query
sql_query = generate_sql_query(user_query)
# Display the generated query
st.subheader("Generated SQL Query")
st.code(sql_query, language="sql")
# Execute the query
results = execute_query(sql_query)
if results is not None:
st.subheader("Query Results")
st.dataframe(results)