query-sql / app.py
Ashhar
restrict non-readonly queries
a2e9487
import streamlit as st
import os
import pandas as pd
from typing import Literal, TypedDict
from sqlalchemy import create_engine, inspect, text
from transformers import AutoTokenizer
from utils import pprint
import time
import re
from openai import OpenAI
import anthropic
from clients.openRouter import OpenRouter
# Load environment variables
from dotenv import load_dotenv
# Pull variables from a local .env file into os.environ (the API keys below are
# read from there via os.environ.get).
load_dotenv()

# Page-level Streamlit settings: browser title/icon, layout width, sidebar state.
_PAGE_CONFIG = dict(
    page_title="SQL Query Assistant",
    page_icon="💾",
    layout="centered",
    initial_sidebar_state="collapsed",
)
st.set_page_config(**_PAGE_CONFIG)
# Keys of MODEL_CONFIG below (get_model_type() returns one of these). The previous
# Literal listed names that had no config entry and omitted the two CLAUDE_* keys.
ModelType = Literal["CLAUDE_HAIKU", "CLAUDE_SONNET", "GPT_4o", "DEEPSEEK", "DEEPSEEK_R1"]

# Shape of every MODEL_CONFIG entry.
ModelConfig = TypedDict("ModelConfig", {
    # NOTE(review): the DEEPSEEK entries hold an OpenRouter client, which this union omits.
    "client": OpenAI | anthropic.Anthropic,
    "model": str,  # provider model id; also shown as the label in the sidebar selector
    "max_context": int,
    # Holds the tokenizer instance returned by AutoTokenizer.from_pretrained.
    "tokenizer": AutoTokenizer
})


@st.cache_resource
def _load_tokenizer(repo_id):
    """Load a Hugging Face tokenizer once per server process and reuse it on every rerun."""
    # Streamlit re-executes this whole script on each widget interaction; without
    # caching, every rerun re-instantiated all five tokenizers, three of which are
    # the same "Xenova/gpt-4o" repo.
    # NOTE(review): the cached tokenizer is shared by all sessions/threads; it is
    # only used for encode() here — confirm that stays thread-safe for your version.
    return AutoTokenizer.from_pretrained(repo_id)


MODEL_CONFIG: dict[ModelType, ModelConfig] = {
    "CLAUDE_HAIKU": {
        "client": anthropic.Anthropic(api_key=os.environ.get("ANTHROPIC_API_KEY")),
        "model": "claude-3-5-haiku-20241022",
        # "model": "claude-3-5-sonnet-20241022",
        # "model": "claude-3-5-sonnet-20240620",
        "max_context": 40000,
        "tokenizer": _load_tokenizer("Xenova/claude-tokenizer")
    },
    "CLAUDE_SONNET": {
        "client": anthropic.Anthropic(api_key=os.environ.get("ANTHROPIC_API_KEY")),
        # "model": "claude-3-5-haiku-20241022",
        # "model": "claude-3-5-sonnet-20241022",
        "model": "claude-3-5-sonnet-20240620",
        "max_context": 40000,
        "tokenizer": _load_tokenizer("Xenova/claude-tokenizer")
    },
    "GPT_4o": {
        "client": OpenAI(api_key=os.environ.get("OPENAI_API_KEY")),
        "model": "gpt-4o",
        "max_context": 15000,
        "tokenizer": _load_tokenizer("Xenova/gpt-4o")
    },
    # "GPT_o1": {
    #     "client": OpenAI(api_key=os.environ.get("OPENAI_API_KEY")),
    #     "model": "o1-preview",
    #     "max_context": 15000,
    #     "tokenizer": _load_tokenizer("Xenova/gpt-4o")
    # },
    "DEEPSEEK": {
        "client": OpenRouter(
            api_key=os.environ.get("OPENROUTER_API_KEY"),
        ),
        "model": "deepseek/deepseek-chat",
        "max_context": 30000,
        "tokenizer": _load_tokenizer("Xenova/gpt-4o")
    },
    "DEEPSEEK_R1": {
        "client": OpenRouter(
            api_key=os.environ.get("OPENROUTER_API_KEY"),
        ),
        "model": "deepseek/deepseek-r1",
        "max_context": 30000,
        "tokenizer": _load_tokenizer("Xenova/gpt-4o")
    },
}
def get_model_type():
    """
    Let the user pick a model in the Streamlit sidebar and return its MODEL_CONFIG key.

    The dropdown shows each entry's provider model id (the ``"model"`` field);
    the first MODEL_CONFIG entry is preselected.

    Returns:
        The selected MODEL_CONFIG key.
    """
    available_models = list(MODEL_CONFIG.keys())
    # Select over the config keys themselves and only *display* the model id.
    # The previous approach mapped the chosen label back to a key with next(...),
    # which silently returns the first match if two entries share a model id
    # (e.g. when toggling the commented-out alternatives), making the other
    # entry impossible to select.
    return st.sidebar.selectbox(
        "Select AI Model",
        available_models,
        index=0,
        format_func=lambda model_type: MODEL_CONFIG[model_type]["model"],
    )
# Resolve the sidebar choice into the module-level settings used by the rest of the app.
modelType = get_model_type()
_active = MODEL_CONFIG[modelType]

client = _active["client"]
MODEL = _active["model"]
# Falls back to MODEL when the entry has no (or an empty) "tools_model" value.
TOOLS_MODEL = _active.get("tools_model") or MODEL
MAX_CONTEXT = _active["max_context"]
tokenizer = _active["tokenizer"]

# Provider-family flags derived from the MODEL_CONFIG key prefix.
isClaudeModel = modelType.startswith("CLAUDE")
isDeepSeekModel = modelType.startswith("DEEPSEEK")
def __countTokens(text):
    """Return the number of tokens the active model's tokenizer yields for ``str(text)``, excluding special tokens."""
    return len(tokenizer.encode(str(text), add_special_tokens=False))
# Seed per-session state the first time this script runs for a browser session.
# ipAddress: value of the X-Forwarded-For request header (None when absent).
if "ipAddress" not in st.session_state:
    st.session_state.ipAddress = st.context.headers.get("x-forwarded-for")

# The remaining slots start empty and are filled in once a database is connected.
for _state_key in ("connection_string", "selected_table", "table_schema", "sample_data", "engine"):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = None
def connect_to_db(connection_string):
    """
    Create a SQLAlchemy engine for *connection_string* and store it in session state.

    A connection is opened (and released) once to verify host/credentials before
    the engine is kept. On success, any previously stored engine is disposed so
    its pooled connections are closed instead of lingering for the session.

    Args:
        connection_string: SQLAlchemy database URL entered by the user.

    Returns:
        True if the connection succeeded, False otherwise (an error is shown in the UI).
    """
    engine = None
    try:
        engine = create_engine(connection_string)
        # Test the connection
        with engine.connect():
            pass
    except Exception as e:
        if engine is not None:
            # Release any pool resources created by the failed attempt.
            engine.dispose()
        message = str(e)
        # Some SQLAlchemy versions echo the raw URL in parse errors, which would
        # reveal the password the input field deliberately masks.
        if connection_string:
            message = message.replace(connection_string, "<connection string>")
        st.error(f"Failed to connect to database: {message}")
        return False

    previous_engine = st.session_state.get("engine")
    if previous_engine is not None:
        previous_engine.dispose()
    st.session_state.engine = engine
    return True
def get_table_schema(table_name):
    """
    Describe a table or view in the ``public`` schema, including its comments.

    Args:
        table_name: Name of the table/view as listed by the SQLAlchemy inspector.

    Returns:
        ``None`` when there is no engine or the columns cannot be read. Otherwise::

            {"table_comment": str | None,
             "columns": {name: {"type": str, "comment": str | None}}}

        If reading comments fails, the same shape is returned with the missing
        comments set to ``None``. The previous fallback returned a flat
        ``{name: type}`` dict, which broke callers that index ``["columns"]``.
    """
    if not st.session_state.engine:
        return None
    try:
        columns = inspect(st.session_state.engine).get_columns(table_name)
    except Exception as e:
        st.error(f"Error fetching schema for {table_name}: {str(e)}")
        return None
    column_types = {col['name']: str(col['type']) for col in columns}

    # Table comment. The two-argument obj_description form is used because the
    # one-argument form is deprecated (OIDs are not unique across catalogs).
    table_comment_query = """
    SELECT obj_description(c.oid, 'pg_class') AS table_comment
    FROM pg_class c
    JOIN pg_namespace n ON n.oid = c.relnamespace
    WHERE c.relname = :table_name
    AND n.nspname = 'public'
    """
    # Column comments, read straight from pg_attribute. The previous
    # '"name"'::regclass cast failed for names containing double quotes and
    # resolved the relation through search_path instead of the 'public' schema.
    column_comments_query = """
    SELECT a.attname AS column_name,
           col_description(a.attrelid, a.attnum) AS column_comment
    FROM pg_attribute a
    JOIN pg_class c ON c.oid = a.attrelid
    JOIN pg_namespace n ON n.oid = c.relnamespace
    WHERE c.relname = :table_name
    AND n.nspname = 'public'
    AND a.attnum > 0
    AND NOT a.attisdropped
    """

    table_comment = None
    column_comments = {}
    try:
        with st.session_state.engine.connect() as conn:
            row = conn.execute(text(table_comment_query), {"table_name": table_name}).fetchone()
            table_comment = row[0] if row else None
            comment_rows = conn.execute(
                text(column_comments_query), {"table_name": table_name}
            ).fetchall()
            column_comments = {name: comment for name, comment in comment_rows}
    except Exception as e:
        # Keep the column types; comments that could not be read stay None.
        st.error(f"Error fetching schema comments: {str(e)}")

    return {
        "table_comment": table_comment,
        "columns": {
            col_name: {
                "type": col_type,
                "comment": column_comments.get(col_name)
            }
            for col_name, col_type in column_types.items()
        }
    }
def get_sample_data(table_name):
    """
    Fetch up to three rows from a table/view for display and prompt context.

    Rows are ordered by the first column descending (``ORDER BY 1 DESC``). This
    only approximates "latest rows" when that column grows over time.

    Args:
        table_name: Name of the table/view as listed by the SQLAlchemy inspector.

    Returns:
        A DataFrame; empty when there is no engine or the query fails.
    """
    if not st.session_state.engine:
        return pd.DataFrame()  # Return empty DataFrame instead of None
    # Quote the identifier with the dialect's rules. Unquoted, PostgreSQL folds
    # mixed-case names to lower case, and reserved words or names with spaces
    # are syntax errors, so those tables never got sample rows.
    quoted_name = st.session_state.engine.dialect.identifier_preparer.quote(table_name)
    query = f"SELECT * FROM {quoted_name} ORDER BY 1 DESC LIMIT 3"
    try:
        with st.session_state.engine.connect() as conn:
            df = pd.read_sql(query, conn)
        return df
    except Exception as e:
        st.error(f"Error fetching sample data for {table_name}: {str(e)}")
        return pd.DataFrame()  # Return empty DataFrame on error
# A fenced code block that is untagged or tagged sql/postgresql/postgres (any case).
# A newline (LF or CRLF) must follow the opening fence; the closing fence does not
# need to be on its own line.
_SQL_FENCE_RE = re.compile(
    r"```[ \t]*(?:sql|postgresql|postgres)?[ \t]*\r?\n(.*?)```",
    re.DOTALL | re.IGNORECASE,
)


def clean_sql_response(response: str) -> str:
    """
    Extract the SQL query from a model reply that may wrap it in a markdown fence.

    Returns the stripped contents of the first matching fenced block, or the
    whole stripped reply when there is no such block. The previous pattern only
    matched a lowercase ```` ```sql ```` fence with LF line endings, so
    ```` ``` ````, ```` ```SQL ````, ```` ```postgresql ```` or CRLF replies
    left the backticks in the query that was then executed.
    """
    match = _SQL_FENCE_RE.search(response)
    if match:
        return match.group(1).strip()
    return response.strip()
def is_read_only_query(query: str) -> bool:
"""Check if the query is read-only (SELECT only)."""
# Convert query to uppercase for case-insensitive comparison
query_upper = query.upper()
# List of SQL statements that modify data
modification_statements = [
'INSERT', 'UPDATE', 'DELETE', 'DROP', 'CREATE', 'ALTER', 'TRUNCATE',
'REPLACE', 'MERGE', 'UPSERT', 'GRANT', 'REVOKE'
]
# Check if query starts with any modification statement
return not any(query_upper.strip().startswith(stmt) for stmt in modification_statements)
def execute_query(query):
    """
    Run a generated query against the connected database and return its rows.

    The query must first pass ``is_read_only_query``. It is then executed
    inside a PostgreSQL ``READ ONLY`` transaction, so the server itself rejects
    writes that a text-based check cannot see (e.g. via function calls).

    Args:
        query: SQL text to execute.

    Returns:
        A DataFrame of the result rows, or ``None`` in these cases (errors are
        shown in the UI):

        * there is no engine,
        * the query is rejected,
        * execution fails.
    """
    if not st.session_state.engine:
        return None
    # Check if the query is read-only
    if not is_read_only_query(query):
        st.error("Error: Only SELECT queries are allowed for security reasons.")
        return None
    try:
        start_time = time.time()
        with st.spinner("Executing SQL query..."):
            # Create a connection and begin a transaction
            with st.session_state.engine.begin() as conn:
                # Must be the first statement of the transaction: from here on,
                # PostgreSQL refuses any write in this transaction.
                conn.execute(text("SET TRANSACTION READ ONLY"))
                # Execute the query using text() to ensure proper SQL compilation
                result = conn.execute(text(query))
                # Convert the result to a pandas DataFrame
                df = pd.DataFrame(result.fetchall(), columns=list(result.keys()))
        execution_time = time.time() - start_time
        pprint(f"[Query Execution] Latency: {execution_time:.2f}s")
        return df
    except Exception as e:
        st.error(f"Error executing query: {str(e)}")
        return None
def _describe_db_object(table_name, table_type):
    """Render one selected table/view (columns, comments, sample rows) as a markdown prompt section."""
    schema_info = (st.session_state.get("table_schemas") or {}).get(table_name) or {}
    columns = schema_info.get("columns")
    if isinstance(columns, dict):
        table_comment = schema_info.get("table_comment")
    else:
        # get_table_schema can fall back to a flat {column: type} mapping without
        # comments; indexing ["columns"] on it used to raise KeyError.
        columns = {name: {"type": col_type} for name, col_type in schema_info.items()}
        table_comment = None

    # Build markdown formatted schema
    schema_md = [f"\n\n### {table_type}: {table_name}"]
    # Add table comment if exists
    if table_comment:
        schema_md.append(f"> {table_comment}")
    # Add column details
    schema_md.append("\n**Columns:**")
    for col_name, col_info in columns.items():
        col_type = col_info["type"]
        col_comment = col_info.get("comment")
        # Format column with type and optional comment
        if col_comment:
            schema_md.append(f"- `{col_name}` ({col_type}) - {col_comment}")
        else:
            schema_md.append(f"- `{col_name}` ({col_type})")

    # Add sample data. Sample rows are only stored for non-empty results, so
    # empty tables (or a failed fetch) have no entry; indexing directly used to
    # raise KeyError and abort query generation.
    schema_md.append("\n**Sample Data:**")
    sample_df = (st.session_state.get("sample_data") or {}).get(table_name)
    if sample_df is not None and not sample_df.empty:
        schema_md.append(sample_df.to_markdown(index=False))
    else:
        schema_md.append("_No sample rows available._")
    return "\n".join(schema_md)


def _request_completion(prompt):
    """Send a single-turn prompt to the selected model and return its text reply ("" if none)."""
    if isClaudeModel:
        response = client.messages.create(
            model=MODEL,
            max_tokens=1000,
            messages=[
                {"role": "user", "content": prompt},
            ]
        )
        return response.content[0].text if response.content else ""
    response = client.chat.completions.create(
        model=MODEL,
        messages=[
            {"role": "user", "content": prompt},
        ]
    )
    # message.content may be None (e.g. a refusal); clean_sql_response needs a str.
    return response.choices[0].message.content or ""


def generate_sql_query(user_query):
    """
    Ask the selected LLM for a PostgreSQL query answering *user_query*.

    The prompt contains, for every object in ``st.session_state.selected_tables``,
    the following (via ``_describe_db_object``):

    * its columns,
    * its comments,
    * its sample rows.

    Args:
        user_query: The user's question in plain English.

    Returns:
        The SQL text extracted from the model reply by ``clean_sql_response``.
    """
    # Build context for all selected tables
    tables_context = [
        _describe_db_object(table_name, table_type)
        for table_name, table_type in st.session_state.selected_tables.items()
    ]
    prompt = f"""You are a SQL expert. Generate a valid PostgreSQL query based on the following context and user query.
<AVAILABLE_OBJECTS>
{chr(10).join(tables_context)}
</AVAILABLE_OBJECTS>
Important:
1. Only generate SELECT queries - no INSERT, UPDATE, DELETE, or other data modification statements
2. Only return the SQL query, nothing else
3. The query should be valid PostgreSQL syntax
4. Do not include any explanations or comments
5. Make sure to handle NULL values appropriately
6. If joining tables, use appropriate join conditions based on the schema
7. Use table names with appropriate qualifiers to avoid ambiguity
User Query: {user_query}
"""
    prompt_tokens = __countTokens(prompt)
    print("\n")
    pprint(f"[{MODEL}] Prompt tokens for SQL generation: {prompt_tokens}")
    # Debug prompt in a Streamlit expander, only when the request Origin header mentions localhost
    if 'localhost' in st.context.headers.get("Origin", ""):
        with st.expander("Debug: Prompt Generation"):
            st.write(f"\nUser Query: {user_query}")
            st.write("\nFull Prompt:")
            st.code(prompt, language="text")
    start_time = time.time()
    with st.spinner(f"Generating SQL query using {MODEL}..."):
        raw_response = _request_completion(prompt)
    generation_time = time.time() - start_time
    pprint(f"[{MODEL}] Query Generation Latency: {generation_time:.2f}s")
    return clean_sql_response(raw_response)
# ---- UI: page title and step 1 (database connection) ----
st.title("SQL Query Assistant")
st.header("1. Database Connection")

# Masked input, pre-filled with the last connection string that connected successfully.
connection_string = st.text_input(
    "Enter PostgreSQL Connection String",
    value=st.session_state.connection_string or "",
    type="password",
)

# Only (re)connect when the user has entered something different from the stored string.
is_new_connection = bool(connection_string) and connection_string != st.session_state.connection_string
if is_new_connection and connect_to_db(connection_string):
    st.session_state.connection_string = connection_string
    st.success("Successfully connected to database!")
# Table Selection Section (step 2): shown once a connection string has connected.
if st.session_state.connection_string:
    st.header("2. Database Object Selection")
    inspector = inspect(st.session_state.engine)
    # Get both tables and views
    tables = inspector.get_table_names()
    views = inspector.get_view_names()
    # (name, type) pairs for every table and view, sorted alphabetically by name
    db_objects = [(table, 'Table') for table in tables] + [(view, 'View') for view in views]
    db_objects.sort(key=lambda x: x[0])
    # name -> "Table"/"View", built once instead of a linear next(...) scan per
    # selected object; setdefault keeps the first match, as that scan did.
    object_types = {}
    for obj_name, obj_type in db_objects:
        object_types.setdefault(obj_name, obj_type)
    # Extract just the names for the multiselect
    object_names = [obj[0] for obj in db_objects]
    # Default to 'lsq_leads' if present
    default_selections = ['lsq_leads'] if 'lsq_leads' in object_types else []
    # Create multiselect for table/view selection
    selected_objects = st.multiselect(
        "Select tables/views",
        options=object_names,
        default=default_selections,
        help="You can select multiple tables/views to query across them"
    )
    # Display selected object types
    if selected_objects:
        st.write("Selected objects:")
        for obj in selected_objects:
            st.write(f"- {obj}: {object_types[obj]}")
    # Create containers for schema and data
    schema_container = st.container()
    data_container = st.container()

    # Per-object data lives in session-state dicts; make sure they exist even
    # when nothing is selected so the cleanup below always runs.
    for state_key in ("selected_tables", "table_schemas", "sample_data"):
        if not isinstance(st.session_state.get(state_key), dict):
            st.session_state[state_key] = {}

    # Drop data for objects that are no longer selected. This now also runs when
    # the selection is empty: previously, clearing the multiselect left the old
    # tables in session state, so "3. Query Input" stayed visible and generated
    # SQL against deselected objects (or objects from a previous database).
    current_tables = set(selected_objects)
    for state_key in ("selected_tables", "table_schemas", "sample_data"):
        cached = st.session_state[state_key]
        for table in set(cached) - current_tables:
            del cached[table]

    if selected_objects:
        # Update session state with new selections
        for obj in selected_objects:
            st.session_state.selected_tables[obj] = object_types[obj]
            # Fetch and store schema
            schema = get_table_schema(obj)
            if schema:
                st.session_state.table_schemas[obj] = schema
            # Fetch and store sample data (only non-empty results are kept)
            sample_data = get_sample_data(obj)
            if not sample_data.empty:
                st.session_state.sample_data[obj] = sample_data

        # Display schema and sample data for each selected object
        with schema_container:
            st.subheader("Table/View Schemas")
            for obj in selected_objects:
                if obj in st.session_state.table_schemas:
                    st.write(f"**{obj} Schema:**")
                    st.json(st.session_state.table_schemas[obj])
                    st.write("---")
                else:
                    st.warning(f"Could not fetch schema for {obj}")
        with data_container:
            st.subheader("Sample Data")
            for obj in selected_objects:
                if obj in st.session_state.sample_data and not st.session_state.sample_data[obj].empty:
                    st.write(f"**{obj} (Last 3 rows):**")
                    st.dataframe(
                        st.session_state.sample_data[obj],
                        use_container_width=True,
                        hide_index=True
                    )
                    st.write("---")
                else:
                    st.warning(f"No sample data available for {obj}")
# Query Input Section (step 3): shown only while at least one object is selected.
if st.session_state.get("selected_tables"):
    st.header("3. Query Input")
    user_query = st.text_area("Enter your query in plain English")
    if st.button("Generate and Execute Query"):
        if not user_query or not user_query.strip():
            # Previously an empty (or whitespace-only) question was silently ignored
            # or sent to the model; tell the user instead.
            st.warning("Please enter a question before generating a query.")
        else:
            try:
                # Generate SQL query
                sql_query = generate_sql_query(user_query)
            except Exception as e:
                # Model/API failures would otherwise surface as a raw traceback.
                st.error(f"Error generating SQL query: {str(e)}")
            else:
                if not sql_query:
                    st.error("The model did not return a SQL query.")
                else:
                    # Display the generated query
                    st.subheader("Generated SQL Query")
                    st.code(sql_query, language="sql")
                    # Execute the query
                    results = execute_query(sql_query)
                    if results is not None:
                        st.subheader("Query Results")
                        st.dataframe(results)