|
import streamlit as st |
|
import os |
|
import pandas as pd |
|
from typing import Literal, TypedDict |
|
from sqlalchemy import create_engine, inspect, text |
|
from transformers import AutoTokenizer |
|
from utils import pprint |
|
import time |
|
import re |
|
|
|
from openai import OpenAI |
|
import anthropic |
|
from clients.openRouter import OpenRouter |
|
|
|
|
|
from dotenv import load_dotenv |
|
# Load variables from a .env file (python-dotenv) before the API keys are
# looked up in os.environ below.
load_dotenv()

# Page-level Streamlit settings, applied before any other Streamlit call in
# this script.
_PAGE_SETTINGS = {
    "page_title": "SQL Query Assistant",
    "page_icon": "💾",
    "layout": "centered",
    "initial_sidebar_state": "collapsed",
}
st.set_page_config(**_PAGE_SETTINGS)
|
|
|
# Identifiers usable as MODEL_CONFIG keys. The list is a superset of the keys
# currently configured: it now includes "CLAUDE_HAIKU" and "CLAUDE_SONNET",
# which MODEL_CONFIG defines, and keeps the older names so existing
# annotations stay valid.
ModelType = Literal[
    "GPT_4o",
    "GPT_o1",
    "CLAUDE",
    "CLAUDE_HAIKU",
    "CLAUDE_SONNET",
    "LLAMA",
    "DEEPSEEK",
    "DEEPSEEK_R1",
    "DEEPSEEK_R1_DISTILL",
]


class _RequiredModelConfig(TypedDict):
    """Keys every MODEL_CONFIG entry provides."""

    # API client used to send the prompt.
    # NOTE(review): entries also store OpenRouter clients here; confirm whether
    # OpenRouter is an OpenAI subclass or should be added to this union.
    client: OpenAI | anthropic.Anthropic
    # Model identifier passed to the client's create() call.
    model: str
    # Token budget configured for the model.
    max_context: int
    # Tokenizer used to count prompt tokens.
    tokenizer: AutoTokenizer


class ModelConfig(_RequiredModelConfig, total=False):
    """One MODEL_CONFIG entry: the required keys plus optional overrides."""

    # Optional key; the script reads it with .get("tools_model") and falls
    # back to "model" when it is missing.
    tools_model: str
|
|
|
# Clients and tokenizers shared by several entries below. Each one is created
# once instead of once per entry, so a script run loads two tokenizers rather
# than five and builds one client per provider.
_ANTHROPIC_CLIENT = anthropic.Anthropic(api_key=os.environ.get("ANTHROPIC_API_KEY"))
_OPENROUTER_CLIENT = OpenRouter(
    api_key=os.environ.get("OPENROUTER_API_KEY"),
)
_CLAUDE_TOKENIZER = AutoTokenizer.from_pretrained("Xenova/claude-tokenizer")
_GPT_4O_TOKENIZER = AutoTokenizer.from_pretrained("Xenova/gpt-4o")

# Per-model settings keyed by the identifier offered in the sidebar:
#   client      - API client used to send the prompt
#   model       - model id passed to the client
#   max_context - token budget configured for the model
#   tokenizer   - tokenizer used to count prompt tokens (the DeepSeek entries
#                 reuse the gpt-4o tokenizer)
MODEL_CONFIG: dict[ModelType, ModelConfig] = {
    "CLAUDE_HAIKU": {
        "client": _ANTHROPIC_CLIENT,
        "model": "claude-3-5-haiku-20241022",
        "max_context": 40000,
        "tokenizer": _CLAUDE_TOKENIZER,
    },
    "CLAUDE_SONNET": {
        "client": _ANTHROPIC_CLIENT,
        "model": "claude-3-5-sonnet-20240620",
        "max_context": 40000,
        "tokenizer": _CLAUDE_TOKENIZER,
    },
    "GPT_4o": {
        "client": OpenAI(api_key=os.environ.get("OPENAI_API_KEY")),
        "model": "gpt-4o",
        "max_context": 15000,
        "tokenizer": _GPT_4O_TOKENIZER,
    },
    "DEEPSEEK": {
        "client": _OPENROUTER_CLIENT,
        "model": "deepseek/deepseek-chat",
        "max_context": 30000,
        "tokenizer": _GPT_4O_TOKENIZER,
    },
    "DEEPSEEK_R1": {
        "client": _OPENROUTER_CLIENT,
        "model": "deepseek/deepseek-r1",
        "max_context": 30000,
        "tokenizer": _GPT_4O_TOKENIZER,
    },
}
|
|
|
|
|
def get_model_type():
    """Render the sidebar model picker and return the chosen MODEL_CONFIG key.

    The select box is populated with the MODEL_CONFIG keys and uses
    ``format_func`` to display each entry's ``"model"`` string, so the widget
    returns the key itself. This avoids mapping the displayed label back to a
    key, which was ambiguous when two entries share a model id and raised a
    bare ``StopIteration`` when no entry matched.

    Returns:
        The selected key of MODEL_CONFIG (the first key is preselected).
    """
    available_models = list(MODEL_CONFIG)

    return st.sidebar.selectbox(
        "Select AI Model",
        available_models,
        index=0,
        # Show the provider's model id instead of the internal key.
        format_func=lambda model_type: MODEL_CONFIG[model_type]["model"],
    )
|
|
|
|
|
|
|
# Resolve the sidebar selection into the module-level settings used by the
# rest of the script.
modelType = get_model_type()
_active_config = MODEL_CONFIG[modelType]

client = _active_config["client"]
MODEL = _active_config["model"]
# Optional per-entry "tools_model"; falls back to MODEL when absent or empty.
TOOLS_MODEL = _active_config.get("tools_model") or MODEL
MAX_CONTEXT = _active_config["max_context"]
tokenizer = _active_config["tokenizer"]

# Provider flags derived from the prefix of the selected config key.
isClaudeModel = modelType.startswith("CLAUDE")
isDeepSeekModel = modelType.startswith("DEEPSEEK")
|
|
|
|
|
def __countTokens(text):
    """Return the number of tokens the selected tokenizer yields for ``text``.

    ``text`` is converted with ``str()`` first, and the tokenizer is asked not
    to add special tokens.
    """
    return len(tokenizer.encode(str(text), add_special_tokens=False))
|
|
|
|
|
|
|
# Seed per-session defaults; keys that already exist in the session are left
# untouched.
if "ipAddress" not in st.session_state:
    # Value of the x-forwarded-for request header (None when it is absent).
    st.session_state["ipAddress"] = st.context.headers.get("x-forwarded-for")

for _state_key in (
    "connection_string",
    "selected_table",
    "table_schema",
    "sample_data",
    "engine",
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = None
|
|
|
|
|
def connect_to_db(connection_string):
    """Create an engine for ``connection_string`` and store it in session state.

    A connection is opened and closed once to verify that the database is
    reachable before the engine replaces ``st.session_state.engine``.

    Args:
        connection_string: Database URL passed to ``create_engine``.

    Returns:
        True when the test connection succeeded and the engine was stored,
        False otherwise (the error is shown in the UI).
    """
    engine = None
    try:
        engine = create_engine(connection_string)
        # Open and close one connection purely as a reachability check.
        with engine.connect():
            pass
    except Exception as e:
        # UI boundary: report any failure instead of crashing the app, and
        # release the pool of the engine that will not be used.
        if engine is not None:
            engine.dispose()
        st.error(f"Failed to connect to database: {str(e)}")
        return False

    previous_engine = st.session_state.get("engine")
    st.session_state.engine = engine
    # Close the pooled connections of the engine being replaced so switching
    # databases does not leak them.
    if previous_engine is not None and previous_engine is not engine:
        previous_engine.dispose()
    return True
|
|
|
|
|
def get_table_schema(table_name):
    """Return column types and PostgreSQL comments for ``table_name``.

    Column names and types come from the SQLAlchemy inspector. Table and
    column comments are read from the PostgreSQL catalogs for the ``public``
    schema.

    Args:
        table_name: Name of the table or view to describe.

    Returns:
        None when no engine is stored in session state. Otherwise a dict::

            {
                "table_comment": str | None,
                "columns": {name: {"type": str, "comment": str | None}},
            }

        The same shape is returned when the comment queries fail (comments
        are then None), so callers can always index ``["columns"]``.
        Previously the error path returned a flat ``{name: type}`` dict.
    """
    engine = st.session_state.engine
    if not engine:
        return None

    columns = inspect(engine).get_columns(table_name)
    column_types = {col["name"]: str(col["type"]) for col in columns}

    table_comment_query = """
        SELECT obj_description(c.oid) as table_comment
        FROM pg_class c
        JOIN pg_namespace n ON n.oid = c.relnamespace
        WHERE c.relname = :table_name
        AND n.nspname = 'public'
    """

    column_comments_query = """
        SELECT
            cols.column_name,
            (
                SELECT pg_catalog.col_description(c.oid, cols.ordinal_position::int)
                FROM pg_catalog.pg_class c
                WHERE c.oid = (SELECT ('"' || cols.table_name || '"')::regclass::oid)
                AND c.relname = cols.table_name
            ) as column_comment
        FROM information_schema.columns cols
        WHERE cols.table_name = :table_name
        AND cols.table_schema = 'public'
    """

    table_comment = None
    column_comments = {}
    try:
        with engine.connect() as conn:
            params = {"table_name": table_name}

            comment_row = conn.execute(text(table_comment_query), params).fetchone()
            table_comment = comment_row[0] if comment_row else None

            comment_rows = conn.execute(text(column_comments_query), params).fetchall()
            column_comments = {row[0]: row[1] for row in comment_rows}
    except Exception as e:
        # Comments are optional: report the problem and fall back to a schema
        # without comments, keeping the return shape unchanged.
        st.error(f"Error fetching schema comments: {str(e)}")
        table_comment = None
        column_comments = {}

    return {
        "table_comment": table_comment,
        "columns": {
            col_name: {
                "type": col_type,
                "comment": column_comments.get(col_name),
            }
            for col_name, col_type in column_types.items()
        },
    }
|
|
|
|
|
def get_sample_data(table_name):
    """Fetch up to three rows of ``table_name``, ordered by its first column descending.

    The table name is quoted with the engine dialect's identifier preparer
    before it is placed in the SQL text. Names that need quoting (mixed case,
    reserved words, special characters) therefore work, and the name cannot
    be read as additional SQL.

    Args:
        table_name: Name of the table or view to sample.

    Returns:
        A DataFrame with the sampled rows. It is empty when no engine is
        stored in session state or when the query fails (the error is shown
        in the UI).
    """
    engine = st.session_state.engine
    if not engine:
        return pd.DataFrame()

    try:
        # quote() only adds quotes when the dialect requires them, so plain
        # lower-case names produce the same SQL as before.
        quoted_name = engine.dialect.identifier_preparer.quote(table_name)
        query = f"SELECT * FROM {quoted_name} ORDER BY 1 DESC LIMIT 3"
        with engine.connect() as conn:
            return pd.read_sql(query, conn)
    except Exception as e:
        st.error(f"Error fetching sample data for {table_name}: {str(e)}")
        return pd.DataFrame()
|
|
|
|
|
def clean_sql_response(response: str) -> str:
    """Extract the SQL text from a model response.

    If the response contains a Markdown code fence, the content of the first
    fence is returned. The fence may be tagged ``sql``, ``postgres``,
    ``postgresql`` or ``pgsql`` in any letter case, or be untagged. It may
    use ``\\n`` or ``\\r\\n`` line endings and may be left unterminated (for
    example when the response was cut off). Without a fence the whole
    response is returned.

    Args:
        response: Raw text returned by the model; ``None`` or ``""`` is
            treated as an empty response.

    Returns:
        The extracted SQL with surrounding whitespace removed, or ``""`` for
        an empty response.
    """
    if not response:
        return ""

    fence_match = re.search(
        r"```[ \t]*(?:postgresql|postgres|pgsql|sql)?[ \t]*\r?\n?(.*?)(?:```|\Z)",
        response,
        re.DOTALL | re.IGNORECASE,
    )
    if fence_match:
        return fence_match.group(1).strip()
    return response.strip()
|
|
|
|
|
def is_read_only_query(query: str) -> bool:
    """Return True only for a single SELECT (or WITH ... SELECT) statement.

    The check is an allowlist rather than a prefix blacklist:

    * comments, string literals and quoted identifiers are blanked out first,
      so a leading comment cannot hide the real first keyword and keywords
      inside literals do not trigger false rejections;
    * only one statement is accepted (a trailing ``;`` is allowed);
    * the statement must start with ``SELECT`` or ``WITH``;
    * data-modifying keywords are rejected anywhere in the statement, which
      covers writable CTEs and ``SELECT ... INTO``.

    This is a best-effort textual guard, not a SQL parser. It does not
    understand dollar-quoted or backslash-escaped strings and cannot detect
    functions with side effects. Use a read-only database role when strict
    enforcement is required.

    Args:
        query: SQL text to inspect.

    Returns:
        True when the text looks like a single read-only query, False
        otherwise (including for empty input).
    """
    if not query:
        return False

    visible_sql = re.sub(
        r"'(?:[^']|'')*'"        # single-quoted string literal
        r'|"(?:[^"]|"")*"'       # double-quoted identifier
        r"|--[^\n]*"             # line comment
        r"|/\*.*?\*/",           # block comment
        " ",
        query,
        flags=re.DOTALL,
    )

    # Drop trailing semicolons; any remaining one separates stacked statements.
    statement = visible_sql.strip().rstrip("; \t\r\n")
    if not statement or ";" in statement:
        return False

    keywords = re.findall(r"[A-Z_][A-Z0-9_]*", statement.upper())
    if not keywords or keywords[0] not in ("SELECT", "WITH"):
        return False

    # REPLACE is deliberately absent: replace() is a common string function,
    # and REPLACE statements are already rejected by the first-keyword check.
    modification_keywords = {
        "INSERT", "UPDATE", "DELETE", "DROP", "CREATE", "ALTER", "TRUNCATE",
        "MERGE", "UPSERT", "GRANT", "REVOKE", "INTO",
    }
    return modification_keywords.isdisjoint(keywords)
|
|
|
|
|
def execute_query(query):
    """Run a read-only SQL query and return its rows as a DataFrame.

    The query must pass ``is_read_only_query``. As a second line of defence
    it runs inside a transaction that is always rolled back, so a statement
    that slipped past the textual check cannot persist changes. The previous
    ``engine.begin()`` block committed on success.

    Args:
        query: SQL text to execute.

    Returns:
        A DataFrame with the result rows, or None when no engine is stored in
        session state, the query is rejected, or execution fails (the error
        is shown in the UI).
    """
    engine = st.session_state.engine
    if not engine:
        return None

    if not is_read_only_query(query):
        st.error("Error: Only SELECT queries are allowed for security reasons.")
        return None

    try:
        start_time = time.perf_counter()
        with st.spinner("Executing SQL query..."):
            with engine.connect() as conn:
                transaction = conn.begin()
                try:
                    result = conn.execute(text(query))
                    df = pd.DataFrame(result.fetchall(), columns=list(result.keys()))
                finally:
                    # Never commit: reads are unaffected, writes are discarded.
                    transaction.rollback()
        execution_time = time.perf_counter() - start_time
        pprint(f"[Query Execution] Latency: {execution_time:.2f}s")
        return df
    except Exception as e:
        st.error(f"Error executing query: {str(e)}")
        return None
|
|
|
|
|
def _format_table_context(table_name, table_type):
    """Render one selected table/view as a Markdown section for the prompt."""
    lines = [f"\n\n### {table_type}: {table_name}"]

    schema_info = (st.session_state.get("table_schemas") or {}).get(table_name) or {}
    if isinstance(schema_info.get("columns"), dict):
        # Shape {"table_comment": ..., "columns": {name: {"type", "comment"}}}.
        table_comment = schema_info.get("table_comment")
        columns = schema_info["columns"]
    else:
        # Flat {name: type} mapping (schema stored without comments).
        table_comment = None
        columns = {name: {"type": col_type} for name, col_type in schema_info.items()}

    if table_comment:
        lines.append(f"> {table_comment}")

    lines.append("\n**Columns:**")
    if not columns:
        lines.append("(schema unavailable)")
    for col_name, col_info in columns.items():
        entry = f"- `{col_name}` ({col_info['type']})"
        col_comment = col_info.get("comment")
        if col_comment:
            entry += f" - {col_comment}"
        lines.append(entry)

    lines.append("\n**Sample Data:**")
    # Sample rows are only stored for tables that returned data, so a missing
    # entry is expected and must not raise.
    sample = (st.session_state.get("sample_data") or {}).get(table_name)
    if sample is not None and not sample.empty:
        lines.append(sample.to_markdown(index=False))
    else:
        lines.append("(no sample rows available)")

    return "\n".join(lines)


def generate_sql_query(user_query):
    """Ask the selected model for a PostgreSQL SELECT answering ``user_query``.

    The prompt lists every object in ``st.session_state.selected_tables`` with
    its columns, comments and sample rows, followed by the generation rules
    and the user's question. The request is sent through the Anthropic
    messages API when a Claude model is selected and through the chat
    completions API otherwise.

    Args:
        user_query: The question in plain English.

    Returns:
        The SQL text extracted from the model response by
        ``clean_sql_response`` (an empty response is passed on as ``""``).
    """
    tables_context = [
        _format_table_context(table_name, table_type)
        for table_name, table_type in st.session_state.selected_tables.items()
    ]
    tables_block = "\n".join(tables_context)

    prompt = (
        "You are a SQL expert. Generate a valid PostgreSQL query based on the following context and user query.\n"
        "\n"
        "<AVAILABLE_OBJECTS>\n"
        f"{tables_block}\n"
        # The tag was previously left open; close it so the schema section is
        # clearly delimited from the instructions.
        "</AVAILABLE_OBJECTS>\n"
        "\n"
        "Important:\n"
        "1. Only generate SELECT queries - no INSERT, UPDATE, DELETE, or other data modification statements\n"
        "2. Only return the SQL query, nothing else\n"
        "3. The query should be valid PostgreSQL syntax\n"
        "4. Do not include any explanations or comments\n"
        "5. Make sure to handle NULL values appropriately\n"
        "6. If joining tables, use appropriate join conditions based on the schema\n"
        "7. Use table names with appropriate qualifiers to avoid ambiguity\n"
        "\n"
        f"User Query: {user_query}\n"
    )

    prompt_tokens = __countTokens(prompt)
    print("\n")
    pprint(f"[{MODEL}] Prompt tokens for SQL generation: {prompt_tokens}")
    if prompt_tokens > MAX_CONTEXT:
        # The request is still sent; the warning explains a likely failure or
        # truncated answer and how to avoid it.
        st.warning(
            f"The prompt is {prompt_tokens} tokens, above the {MAX_CONTEXT}-token "
            f"budget configured for {MODEL}. Consider selecting fewer tables."
        )

    # Show the full prompt only when the app is opened from localhost.
    if 'localhost' in st.context.headers.get("Origin", ""):
        with st.expander("Debug: Prompt Generation"):
            st.write(f"\nUser Query: {user_query}")
            st.write("\nFull Prompt:")
            st.code(prompt, language="text")

    messages = [
        {"role": "user", "content": prompt},
    ]

    start_time = time.time()
    with st.spinner(f"Generating SQL query using {MODEL}..."):
        if isClaudeModel:
            response = client.messages.create(
                model=MODEL,
                max_tokens=1000,
                messages=messages,
            )
            raw_response = response.content[0].text if response.content else ""
        else:
            response = client.chat.completions.create(
                model=MODEL,
                messages=messages,
            )
            raw_response = response.choices[0].message.content

    generation_time = time.time() - start_time
    pprint(f"[{MODEL}] Query Generation Latency: {generation_time:.2f}s")

    # message.content can be None; normalise before extracting the SQL.
    return clean_sql_response(raw_response or "")
|
|
|
|
|
|
|
st.title("SQL Query Assistant")

# Step 1: ask for the connection string, prefilled with the one already
# connected in this session.
st.header("1. Database Connection")
connection_string = st.text_input(
    "Enter PostgreSQL Connection String",
    value=st.session_state.connection_string or "",
    type="password",
)

# Connect only when the field holds a value that differs from the string the
# session is already connected with.
_is_new_connection_string = (
    bool(connection_string)
    and connection_string != st.session_state.connection_string
)
if _is_new_connection_string and connect_to_db(connection_string):
    st.session_state.connection_string = connection_string
    st.success("Successfully connected to database!")
|
|
|
|
|
# Step 2: let the user pick tables/views, then load and display their schema
# and sample rows. Shown only once a connection has been established.
if st.session_state.connection_string:
    st.header("2. Database Object Selection")
    inspector = inspect(st.session_state.engine)

    tables = inspector.get_table_names()
    views = inspector.get_view_names()

    # (name, kind) pairs sorted by name.
    db_objects = sorted(
        [(table, 'Table') for table in tables] + [(view, 'View') for view in views],
        key=lambda item: item[0],
    )
    object_names = [obj[0] for obj in db_objects]
    # name -> kind lookup; iterating in reverse lets the first pair in
    # db_objects win for a duplicated name, like the former next() search.
    object_types = {obj_name: obj_type for obj_name, obj_type in reversed(db_objects)}

    default_selections = ['lsq_leads'] if 'lsq_leads' in object_types else []

    selected_objects = st.multiselect(
        "Select tables/views",
        options=object_names,
        default=default_selections,
        help="You can select multiple tables/views to query across them"
    )

    if selected_objects:
        st.write("Selected objects:")
        for obj in selected_objects:
            st.write(f"- {obj}: {object_types[obj]}")

    # Placeholders so schemas are rendered above sample data.
    schema_container = st.container()
    data_container = st.container()

    # The per-object caches must be dicts (sample_data starts out as None).
    for _cache_key in ("selected_tables", "table_schemas", "sample_data"):
        if not isinstance(st.session_state.get(_cache_key), dict):
            st.session_state[_cache_key] = {}

    # Forget objects that are no longer selected. This runs even when the
    # selection is empty; otherwise deselecting everything left the previous
    # tables in session state and the query step kept using them.
    for _removed in set(st.session_state.selected_tables) - set(selected_objects):
        st.session_state.selected_tables.pop(_removed, None)
        st.session_state.table_schemas.pop(_removed, None)
        st.session_state.sample_data.pop(_removed, None)

    if selected_objects:
        for obj in selected_objects:
            st.session_state.selected_tables[obj] = object_types[obj]

            schema = get_table_schema(obj)
            if schema:
                st.session_state.table_schemas[obj] = schema
            else:
                # Drop a stale entry instead of showing an outdated schema.
                st.session_state.table_schemas.pop(obj, None)

            sample_data = get_sample_data(obj)
            if not sample_data.empty:
                st.session_state.sample_data[obj] = sample_data
            else:
                # Same for sample rows that could not be refreshed.
                st.session_state.sample_data.pop(obj, None)

        with schema_container:
            st.subheader("Table/View Schemas")
            for obj in selected_objects:
                if obj in st.session_state.table_schemas:
                    st.write(f"**{obj} Schema:**")
                    st.json(st.session_state.table_schemas[obj])
                    st.write("---")
                else:
                    st.warning(f"Could not fetch schema for {obj}")

        with data_container:
            st.subheader("Sample Data")
            for obj in selected_objects:
                if obj in st.session_state.sample_data and not st.session_state.sample_data[obj].empty:
                    st.write(f"**{obj} (Last 3 rows):**")
                    st.dataframe(
                        st.session_state.sample_data[obj],
                        use_container_width=True,
                        hide_index=True
                    )
                    st.write("---")
                else:
                    st.warning(f"No sample data available for {obj}")
|
|
|
|
|
# Step 3: turn a plain-English question into SQL, show it, and run it.
# Shown only while at least one table/view is selected.
if st.session_state.get("selected_tables"):
    st.header("3. Query Input")
    user_query = st.text_area("Enter your query in plain English")

    if st.button("Generate and Execute Query"):
        if user_query and user_query.strip():
            sql_query = generate_sql_query(user_query)

            st.subheader("Generated SQL Query")
            st.code(sql_query, language="sql")

            # execute_query reports its own errors and returns None then.
            results = execute_query(sql_query)
            if results is not None:
                st.subheader("Query Results")
                st.dataframe(results)
        else:
            # Blank or whitespace-only input: tell the user instead of doing
            # nothing or sending an empty question to the model.
            st.warning("Please enter a question before generating a query.")
|
|
|
|
|
|