# query-sql / app.py
# Ashhar — "logging format" (557b97f)
# raw | history | blame | 11 kB
import streamlit as st
import os
import pandas as pd
from typing import Literal, TypedDict
from sqlalchemy import create_engine, inspect
import json
from transformers import AutoTokenizer
from utils import pprint
import time
import re
from openai import OpenAI
import anthropic
from clients.openRouter import OpenRouter
# Load environment variables
from dotenv import load_dotenv
load_dotenv()

# Browser-tab title/icon and overall layout for the app.
_PAGE_SETTINGS = {
    "page_title": "SQL Query Assistant",
    "page_icon": "💾",
    "layout": "centered",
    "initial_sidebar_state": "collapsed",
}
st.set_page_config(**_PAGE_SETTINGS)
# Identifiers for model backends; not all of them have an entry in MODEL_CONFIG.
ModelType = Literal[
    "GPT_4o",
    "GPT_o1",
    "CLAUDE",
    "LLAMA",
    "DEEPSEEK",
    "DEEPSEEK_R1",
    "DEEPSEEK_R1_DISTILL",
]


class ModelConfig(TypedDict):
    """Settings for one model: API client, provider model id, context limit, tokenizer."""

    client: OpenAI | anthropic.Anthropic
    model: str
    # Context-size limit for this model (presumably in tokens — confirm).
    max_context: int
    tokenizer: AutoTokenizer
@st.cache_resource(show_spinner="Loading model clients and tokenizers...")
def _build_model_config() -> dict[ModelType, ModelConfig]:
    """Create the API client and tokenizer for every enabled model.

    Streamlit re-executes this script from the top on every widget interaction.
    ``st.cache_resource`` makes the slow tokenizer loads and client construction
    happen once per server process instead of on every rerun. The resulting
    clients and tokenizers are shared by all sessions.

    API keys are read from the environment when the cache is first populated.
    After changing them, restart the server (or clear the resource cache).
    """
    # The GPT-4o tokenizer is reused for several entries; load it only once.
    # For the DeepSeek models its counts are presumably an approximation.
    gpt4o_tokenizer = AutoTokenizer.from_pretrained("Xenova/gpt-4o")
    return {
        "CLAUDE": {
            "client": anthropic.Anthropic(api_key=os.environ.get("ANTHROPIC_API_KEY")),
            # Alternatives: "claude-3-5-sonnet-20241022", "claude-3-5-sonnet-20240620"
            "model": "claude-3-5-haiku-20241022",
            "max_context": 40000,
            "tokenizer": AutoTokenizer.from_pretrained("Xenova/claude-tokenizer"),
        },
        "GPT_4o": {
            "client": OpenAI(api_key=os.environ.get("OPENAI_API_KEY")),
            "model": "gpt-4o",
            "max_context": 15000,
            "tokenizer": gpt4o_tokenizer,
        },
        # "GPT_o1" (OpenAI client, model "o1-preview", max_context 15000,
        # GPT-4o tokenizer) is currently disabled.
        "DEEPSEEK": {
            "client": OpenRouter(api_key=os.environ.get("OPENROUTER_API_KEY")),
            "model": "deepseek/deepseek-chat",
            "max_context": 30000,
            "tokenizer": gpt4o_tokenizer,
        },
        "DEEPSEEK_R1": {
            "client": OpenRouter(api_key=os.environ.get("OPENROUTER_API_KEY")),
            "model": "deepseek/deepseek-r1",
            "max_context": 30000,
            "tokenizer": gpt4o_tokenizer,
        },
    }


MODEL_CONFIG: dict[ModelType, ModelConfig] = _build_model_config()
def get_model_type() -> ModelType:
    """
    Render a sidebar selectbox of the configured models and return the chosen key.

    The selectbox options are the MODEL_CONFIG keys themselves. ``format_func``
    only controls the label, which is each entry's provider model id. Selecting
    the key directly avoids mapping a label back to a key, a mapping that would
    silently pick the first match if two entries ever shared the same model id.

    Returns:
        The MODEL_CONFIG key (e.g. "CLAUDE") of the selected model.
    """
    return st.sidebar.selectbox(
        "Select AI Model",
        list(MODEL_CONFIG.keys()),
        index=0,
        format_func=lambda model_type: MODEL_CONFIG[model_type]["model"],
    )
# Resolve everything the rest of the script needs from the chosen model, once per run.
modelType = get_model_type()
_selected_config = MODEL_CONFIG[modelType]
client, MODEL, MAX_CONTEXT, tokenizer = (
    _selected_config[field] for field in ("client", "model", "max_context", "tokenizer")
)
# Fall back to the main model when the config has no dedicated "tools_model".
TOOLS_MODEL = _selected_config.get("tools_model") or MODEL
# Provider flags (the Anthropic client uses a different request/response shape).
isClaudeModel = modelType == "CLAUDE"
isDeepSeekModel = modelType.startswith("DEEPSEEK")
def __countTokens(text):
    """Return the token count of ``text`` under the active model's tokenizer.

    The value is coerced with ``str()`` first, so non-string inputs are counted
    by their string form. No special tokens are added before counting.
    """
    return len(tokenizer.encode(str(text), add_special_tokens=False))
# Seed per-session state; the membership checks make this a one-time step per session.
if "ipAddress" not in st.session_state:
    # Client address as forwarded by a proxy; None when the header is absent.
    st.session_state.ipAddress = st.context.headers.get("x-forwarded-for")

for _state_key in ("connection_string", "selected_table", "table_schema", "sample_data", "engine"):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = None
def connect_to_db(connection_string):
    """Create an engine for ``connection_string`` and verify that it can connect.

    On success the engine is stored in ``st.session_state.engine`` and ``True``
    is returned. Any previously stored engine is disposed so its pooled
    connections are released.

    On failure the error is shown in the UI and ``False`` is returned. The
    half-built engine is disposed and the previous engine is left in place.
    """
    engine = None
    try:
        engine = create_engine(connection_string)
        # Open (and immediately release) one connection to test reachability/credentials.
        with engine.connect():
            pass
    except Exception as e:
        if engine is not None:
            engine.dispose()
        st.error(f"Failed to connect to database: {str(e)}")
        return False

    previous_engine = st.session_state.engine
    st.session_state.engine = engine
    if previous_engine is not None and previous_engine is not engine:
        # Without this, every reconnect would leave the old pool's connections open.
        previous_engine.dispose()
    return True
def get_table_schema(table_name):
    """Describe the columns of ``table_name`` on the connected engine.

    Returns:
        A dict mapping each column name to the string form of its SQLAlchemy
        type, or ``None`` when no engine has been connected yet.
    """
    engine = st.session_state.engine
    if not engine:
        return None
    schema = {}
    for column in inspect(engine).get_columns(table_name):
        schema[column["name"]] = str(column["type"])
    return schema
def get_sample_data(table_name):
    """Fetch up to three rows of ``table_name`` as a DataFrame preview.

    Rows are ordered by the first column, descending. They are the "latest"
    rows only if that column increases over time.

    The table name is quoted with the engine dialect's identifier preparer. It
    is quoted only when required, so mixed-case names, reserved words and names
    with special characters resolve correctly. Unquoted, PostgreSQL would fold
    ``MyTable`` to ``mytable``.

    Returns:
        A DataFrame, or ``None`` when no engine is connected or the query fails
        (the error is shown in the UI).
    """
    engine = st.session_state.engine
    if not engine:
        return None
    try:
        quoted_table = engine.dialect.identifier_preparer.quote(table_name)
        query = f"SELECT * FROM {quoted_table} ORDER BY 1 DESC LIMIT 3"
        with engine.connect() as conn:
            return pd.read_sql(query, conn)
    except Exception as e:
        st.error(f"Error fetching sample data: {str(e)}")
        return None
def clean_sql_response(response: str) -> str:
"""Extract clean SQL query from a potentially formatted response."""
# If response contains SQL code block, extract it
sql_block_match = re.search(r'```sql\n(.*?)\n```', response, re.DOTALL)
if sql_block_match:
return sql_block_match.group(1).strip()
return response.strip()
def execute_query(query):
    """Run a (model-generated) SQL query and return its result set.

    The SQL comes from an LLM, so it runs inside an explicit transaction that
    is always rolled back. Any INSERT/UPDATE/DELETE (and DDL, on databases with
    transactional DDL such as PostgreSQL) is therefore discarded. ``pd.read_sql``
    materialises all rows before the rollback, so reads are unaffected.

    Returns:
        A DataFrame of results, or ``None`` when no engine is connected or the
        query fails (the error is shown in the UI).
    """
    engine = st.session_state.engine
    if not engine:
        return None
    try:
        start_time = time.perf_counter()
        with st.spinner("Executing SQL query..."):
            with engine.connect() as conn:
                # Not `with conn.begin()`: that would COMMIT on success.
                transaction = conn.begin()
                try:
                    df = pd.read_sql(query, conn)
                finally:
                    transaction.rollback()
        execution_time = time.perf_counter() - start_time
        pprint(f"[Query Execution] Latency: {execution_time:.2f}s")
        return df
    except Exception as e:
        st.error(f"Error executing query: {str(e)}")
        return None
def _format_sample_rows(sample_data):
    """Render the cached sample rows for the prompt, tolerating missing data.

    ``sample_data`` is ``None`` when fetching the preview failed.
    ``DataFrame.to_markdown`` needs the optional ``tabulate`` package, so this
    falls back to ``to_string`` when that package is not installed.
    """
    if sample_data is None:
        return "(sample data unavailable)"
    try:
        return sample_data.to_markdown(index=False)
    except ImportError:
        return sample_data.to_string(index=False)


def _request_completion(prompt):
    """Send ``prompt`` as a single user message to the active model and return its text."""
    messages = [{"role": "user", "content": prompt}]
    if isClaudeModel:
        # The Anthropic Messages API requires an explicit max_tokens.
        response = client.messages.create(
            model=MODEL,
            max_tokens=1000,
            messages=messages,
        )
        return response.content[0].text
    response = client.chat.completions.create(
        model=MODEL,
        messages=messages,
    )
    # Chat-completions `content` may be None (e.g. a refusal); normalise to "".
    return response.choices[0].message.content or ""


def generate_sql_query(user_query):
    """Ask the selected LLM to translate ``user_query`` into a PostgreSQL query.

    The prompt embeds the selected table's name, schema and sample rows from
    ``st.session_state``. It is logged (token count) and shown in a collapsed
    debug expander.

    Returns:
        The SQL text extracted from the model reply by ``clean_sql_response``.
        Provider/API exceptions propagate to the caller.
    """
    prompt = f"""You are a SQL expert. Generate a valid PostgreSQL query based on the following context and user query.
Table Name: {st.session_state.selected_table}
Table Schema:
{json.dumps(st.session_state.table_schema, indent=2)}
Sample Data:
{_format_sample_rows(st.session_state.sample_data)}
Important:
1. Only return the SQL query, nothing else
2. The query should be valid PostgreSQL syntax
3. Do not include any explanations or comments
4. Make sure to handle NULL values appropriately
5. Use the table name '{st.session_state.selected_table}' in your query
User Query: {user_query}
"""
    prompt_tokens = __countTokens(prompt)
    pprint(f"\n[{MODEL}] Prompt tokens for SQL generation: {prompt_tokens}")

    # Show exactly what is sent to the model, collapsed by default.
    with st.expander("Debug: Prompt Generation"):
        st.write(f"\nUser Query: {user_query}")
        st.write("\nFull Prompt:")
        st.code(prompt, language="text")

    start_time = time.perf_counter()
    with st.spinner(f"Generating SQL query using {MODEL}..."):
        raw_response = _request_completion(prompt)
    generation_time = time.perf_counter() - start_time
    pprint(f"[{MODEL}] Query Generation Latency: {generation_time:.2f}s")
    return clean_sql_response(raw_response)
# ---- UI -----------------------------------------------------------------
st.title("SQL Query Assistant")

# Step 1: connection string (masked input, since it may embed credentials).
st.header("1. Database Connection")
connection_string = st.text_input(
    "Enter PostgreSQL Connection String",
    value=st.session_state.connection_string or "",
    type="password",
)

# Attempt a connection only for a non-empty, changed string. The string is
# recorded in session state only after connect_to_db reports success.
is_new_connection = bool(connection_string) and connection_string != st.session_state.connection_string
if is_new_connection and connect_to_db(connection_string):
    st.session_state.connection_string = connection_string
    st.success("Successfully connected to database!")
# Table Selection Section
if st.session_state.connection_string:
    st.header("2. Table Selection")
    inspector = inspect(st.session_state.engine)
    tables = inspector.get_table_names()

    if not tables:
        # Say so explicitly. Also drop the table/schema/sample left over from a
        # previous connection: the query section is gated on
        # st.session_state.selected_table and would otherwise target a table
        # that does not exist in this database.
        st.info("No tables found in this database.")
        selected_table = None
        st.session_state.selected_table = None
        st.session_state.table_schema = None
        st.session_state.sample_data = None
    else:
        # Pre-select 'lsq_leads' when it exists, otherwise the first table.
        default_index = tables.index("lsq_leads") if "lsq_leads" in tables else 0
        selected_table = st.selectbox("Select a table", tables, index=default_index)

        # Containers for the schema and sample-data sections, in display order.
        schema_container = st.container()
        data_container = st.container()

        if selected_table:
            st.session_state.selected_table = selected_table
            # Re-read on every rerun so the preview reflects the current connection and data.
            st.session_state.table_schema = get_table_schema(selected_table)
            st.session_state.sample_data = get_sample_data(selected_table)

        with schema_container:
            if st.session_state.table_schema:
                st.subheader("Table Schema")
                st.json(st.session_state.table_schema)

        with data_container:
            if st.session_state.sample_data is not None:
                st.subheader("Sample Data (Last 3 rows)")
                st.dataframe(
                    st.session_state.sample_data,
                    use_container_width=True,
                    hide_index=True,
                )
# Query Input Section
if st.session_state.selected_table:
    st.header("3. Query Input")
    user_query = st.text_area("Enter your query in plain English")

    if st.button("Generate and Execute Query"):
        if not user_query.strip():
            st.warning("Please describe what you want to query first.")
        else:
            try:
                sql_query = generate_sql_query(user_query)
            except Exception as e:
                # Provider/network failures: report in the UI instead of a page-level traceback.
                st.error(f"Error generating SQL query: {e}")
            else:
                if not sql_query:
                    st.warning("The model returned an empty response; try rephrasing your query.")
                else:
                    st.subheader("Generated SQL Query")
                    st.code(sql_query, language="sql")

                    results = execute_query(sql_query)
                    if results is not None:
                        st.subheader("Query Results")
                        st.dataframe(results)