|
import streamlit as st |
|
import os |
|
import pandas as pd |
|
from typing import Literal, TypedDict |
|
from sqlalchemy import create_engine, inspect |
|
import json |
|
from transformers import AutoTokenizer |
|
from utils import pprint |
|
import time |
|
import re |
|
|
|
from openai import OpenAI |
|
import anthropic |
|
from clients.openRouter import OpenRouter |
|
|
|
|
|
from dotenv import load_dotenv |
|
# Read variables from a .env file into the process environment; the API
# clients created further down fetch their keys through os.environ.
load_dotenv()

# Page-level Streamlit settings, applied in a single call.
_PAGE_SETTINGS = {
    "page_title": "SQL Query Assistant",
    "page_icon": "💾",
    "layout": "centered",
    "initial_sidebar_state": "collapsed",
}
st.set_page_config(**_PAGE_SETTINGS)
|
|
|
# Names under which a model can be registered in MODEL_CONFIG.
ModelType = Literal[
    "GPT_4o",
    "GPT_o1",
    "CLAUDE",
    "LLAMA",
    "DEEPSEEK",
    "DEEPSEEK_R1",
    "DEEPSEEK_R1_DISTILL",
]


class _RequiredModelConfig(TypedDict):
    """Keys that every model registry entry provides."""

    # API client object the requests are sent through.
    # NOTE(review): the registry also stores OpenRouter clients; presumably
    # OpenRouter is compatible with (or derives from) OpenAI - confirm.
    client: OpenAI | anthropic.Anthropic
    # Model identifier passed to the provider API.
    model: str
    # Context-size limit for the model (presumably in tokens - confirm).
    max_context: int
    # Tokenizer used to count prompt tokens.
    tokenizer: AutoTokenizer


class ModelConfig(_RequiredModelConfig, total=False):
    """One model registry entry.

    Adds the optional ``tools_model`` key, which the module reads with
    ``.get("tools_model")`` and replaces with ``model`` when it is missing.
    """

    # Assumed to be a model identifier string like ``model`` - TODO confirm.
    tools_model: str
|
|
|
# Streamlit re-executes this module on every interaction, so each
# from_pretrained() call below runs again on every rerun. The GPT-4o
# tokenizer is used by three entries; load it once and share the instance
# instead of loading it three times.
_GPT_4O_TOKENIZER = AutoTokenizer.from_pretrained("Xenova/gpt-4o")

# Registry of selectable models: API client, provider model name, context
# limit and the tokenizer used for prompt token counts.
MODEL_CONFIG: dict[ModelType, ModelConfig] = {
    "CLAUDE": {
        "client": anthropic.Anthropic(api_key=os.environ.get("ANTHROPIC_API_KEY")),
        "model": "claude-3-5-haiku-20241022",
        "max_context": 40000,
        "tokenizer": AutoTokenizer.from_pretrained("Xenova/claude-tokenizer"),
    },
    "GPT_4o": {
        "client": OpenAI(api_key=os.environ.get("OPENAI_API_KEY")),
        "model": "gpt-4o",
        "max_context": 15000,
        "tokenizer": _GPT_4O_TOKENIZER,
    },
    # The DeepSeek entries reuse the GPT-4o tokenizer, so their token counts
    # are presumably approximations rather than exact figures.
    "DEEPSEEK": {
        "client": OpenRouter(
            api_key=os.environ.get("OPENROUTER_API_KEY"),
        ),
        "model": "deepseek/deepseek-chat",
        "max_context": 30000,
        "tokenizer": _GPT_4O_TOKENIZER,
    },
    "DEEPSEEK_R1": {
        "client": OpenRouter(
            api_key=os.environ.get("OPENROUTER_API_KEY"),
        ),
        "model": "deepseek/deepseek-r1",
        "max_context": 30000,
        "tokenizer": _GPT_4O_TOKENIZER,
    },
}
|
|
|
|
|
def get_model_type():
    """Let the user pick a model in the sidebar and return its registry key.

    The select box lists each registry entry's provider model name; the
    chosen name is mapped back to the first registry key that carries it.
    """
    registry_keys = list(MODEL_CONFIG.keys())
    labels = [MODEL_CONFIG[key]["model"] for key in registry_keys]

    chosen_label = st.sidebar.selectbox("Select AI Model", labels, index=0)

    # Labels and keys are index-aligned, so walk them together.
    return next(
        key
        for key, label in zip(registry_keys, labels)
        if label == chosen_label
    )
|
|
|
|
|
|
|
# Resolve the model picked in the sidebar and expose its settings as
# module-level globals used by the functions below.
modelType = get_model_type()
_selected_config = MODEL_CONFIG[modelType]

client = _selected_config["client"]
MODEL = _selected_config["model"]
# Use the main model when the entry has no (truthy) "tools_model" value.
TOOLS_MODEL = _selected_config.get("tools_model") or MODEL
MAX_CONTEXT = _selected_config["max_context"]
tokenizer = _selected_config["tokenizer"]

# Provider flags derived from the registry key.
isClaudeModel = modelType == "CLAUDE"
isDeepSeekModel = modelType.startswith("DEEPSEEK")
|
|
|
|
|
def __countTokens(text):
    """Return the number of tokens the active tokenizer yields for ``text``.

    ``text`` is converted with ``str()`` first and special tokens are not
    added to the encoding.
    """
    token_ids = tokenizer.encode(str(text), add_special_tokens=False)
    return len(token_ids)
|
|
|
|
|
|
|
# Seed the per-session state; keys that already exist keep their value.
if "ipAddress" not in st.session_state:
    # Value of the "x-forwarded-for" request header, as reported by Streamlit.
    st.session_state.ipAddress = st.context.headers.get("x-forwarded-for")

# Everything else starts out empty until the user connects and picks a table.
for _state_key in (
    "connection_string",
    "selected_table",
    "table_schema",
    "sample_data",
    "engine",
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = None
|
|
|
|
|
def connect_to_db(connection_string):
    """Create an engine for ``connection_string`` and check that it connects.

    On success the engine is stored in ``st.session_state.engine`` and any
    engine stored there before is disposed so its pooled connections are
    released. On failure the error is shown with ``st.error``, the new
    engine (if one was created) is disposed, and the stored engine is left
    untouched.

    Args:
        connection_string: Database URL handed to ``create_engine``.

    Returns:
        ``True`` when a connection could be opened, ``False`` otherwise.
    """
    engine = None
    try:
        engine = create_engine(connection_string)
        # Opening and closing one connection proves the URL is usable.
        with engine.connect():
            pass
    except Exception as e:
        # Do not leave a half-initialised connection pool behind.
        if engine is not None:
            engine.dispose()
        st.error(f"Failed to connect to database: {str(e)}")
        return False

    previous_engine = st.session_state.get("engine")
    st.session_state.engine = engine
    # Release the pool of the engine that was just replaced.
    if previous_engine is not None and previous_engine is not engine:
        previous_engine.dispose()
    return True
|
|
|
|
|
def get_table_schema(table_name):
    """Map each column of ``table_name`` to the string form of its type.

    Returns ``None`` when the session state holds no engine.
    """
    engine = st.session_state.engine
    if not engine:
        return None

    schema = {}
    for column in inspect(engine).get_columns(table_name):
        schema[column["name"]] = str(column["type"])
    return schema
|
|
|
|
|
def get_sample_data(table_name):
    """Fetch three rows of ``table_name``, ordered by its first column descending.

    The table name is quoted with the engine dialect's identifier preparer,
    so mixed-case, reserved-word or otherwise special names produce valid
    SQL. The name is treated as one identifier, not as ``schema.table``.

    Args:
        table_name: Name of the table to preview.

    Returns:
        A DataFrame with up to three rows, or ``None`` when no engine is
        stored in the session state or the query fails (the error is shown
        with ``st.error``).
    """
    engine = st.session_state.engine
    if not engine:
        return None

    try:
        # Quote only when required; plain lowercase names stay unchanged.
        quoted_table = engine.dialect.identifier_preparer.quote(table_name)
        query = f"SELECT * FROM {quoted_table} ORDER BY 1 DESC LIMIT 3"
        with engine.connect() as conn:
            return pd.read_sql(query, conn)
    except Exception as e:
        st.error(f"Error fetching sample data: {str(e)}")
        return None
|
|
|
|
|
def clean_sql_response(response: str) -> str:
    """Extract a bare SQL statement from a model response.

    A Markdown code fence tagged as SQL (in any letter case) is preferred;
    otherwise the first fenced block of any kind is used. Both LF and CRLF
    line endings are accepted. When the response has no fenced block, the
    whole response is returned.

    Args:
        response: Raw text returned by the model. ``None`` or an empty
            string is tolerated.

    Returns:
        The extracted text with surrounding whitespace removed, or ``""``
        when ``response`` is empty or ``None``.
    """
    if not response:
        return ""

    fence_patterns = (
        # Preferred: a fence explicitly tagged as SQL.
        r"```[ \t]*(?:sql|postgresql|postgres|pgsql)[ \t]*\r?\n(.*?)```",
        # Fallback: any fenced block, with or without a language tag.
        r"```[^\n`]*\r?\n(.*?)```",
    )
    for pattern in fence_patterns:
        fenced = re.search(pattern, response, re.DOTALL | re.IGNORECASE)
        if fenced:
            return fenced.group(1).strip()

    return response.strip()
|
|
|
|
|
def execute_query(query):
    """Run ``query`` on the session's engine and return the result as a DataFrame.

    Returns ``None`` when no engine is stored in the session state or when
    execution fails; failures are shown with ``st.error``. The elapsed time
    is logged through ``pprint``.

    NOTE(review): the SQL is executed exactly as given; nothing in this
    function restricts it to read-only statements.
    """
    engine = st.session_state.engine
    if not engine:
        return None

    try:
        started_at = time.time()
        with st.spinner("Executing SQL query..."), engine.connect() as conn:
            result_frame = pd.read_sql(query, conn)
        elapsed = time.time() - started_at
        pprint(f"[Query Execution] Latency: {elapsed:.2f}s")
        return result_frame
    except Exception as e:
        st.error(f"Error executing query: {str(e)}")
        return None
|
|
|
|
|
def generate_sql_query(user_query):
    """Ask the selected model to turn ``user_query`` into a PostgreSQL query.

    The prompt is built from the table name, schema and sample rows held in
    the session state. It is shown in a debug expander, and its token count
    and the request latency are logged through ``pprint``. The Claude client
    is called through ``messages.create`` (with ``max_tokens=1000``); every
    other client is called through ``chat.completions.create``.

    Args:
        user_query: The request in plain English.

    Returns:
        The model's reply after ``clean_sql_response`` has been applied.
    """
    # get_sample_data() returns None when the preview query fails; fall back
    # to a placeholder instead of raising AttributeError on None.
    sample_data = st.session_state.sample_data
    if sample_data is None:
        sample_data_text = "(no sample data available)"
    else:
        sample_data_text = sample_data.to_markdown(index=False)

    prompt = f"""You are a SQL expert. Generate a valid PostgreSQL query based on the following context and user query.

Table Name: {st.session_state.selected_table}

Table Schema:
{json.dumps(st.session_state.table_schema, indent=2)}

Sample Data:
{sample_data_text}

Important:
1. Only return the SQL query, nothing else
2. The query should be valid PostgreSQL syntax
3. Do not include any explanations or comments
4. Make sure to handle NULL values appropriately
5. Use the table name '{st.session_state.selected_table}' in your query

User Query: {user_query}
"""

    prompt_tokens = __countTokens(prompt)
    pprint(f"\n[{MODEL}] Prompt tokens for SQL generation: {prompt_tokens}")

    with st.expander("Debug: Prompt Generation"):
        st.write(f"\nUser Query: {user_query}")
        st.write("\nFull Prompt:")
        st.code(prompt, language="text")

    start_time = time.time()
    with st.spinner(f"Generating SQL query using {MODEL}..."):
        if isClaudeModel:
            response = client.messages.create(
                model=MODEL,
                max_tokens=1000,
                messages=[
                    {"role": "user", "content": prompt},
                ],
            )
            raw_response = response.content[0].text
        else:
            response = client.chat.completions.create(
                model=MODEL,
                messages=[
                    {"role": "user", "content": prompt},
                ],
            )
            raw_response = response.choices[0].message.content

    generation_time = time.time() - start_time
    pprint(f"[{MODEL}] Query Generation Latency: {generation_time:.2f}s")

    # A chat completion's message content may be None; normalise it to an
    # empty string before cleaning.
    return clean_sql_response(raw_response or "")
|
|
|
|
|
|
|
st.title("SQL Query Assistant")

# Step 1: collect the connection string; the field is masked like a password
# and pre-filled with the value stored for this session, if any.
st.header("1. Database Connection")
connection_string = st.text_input(
    "Enter PostgreSQL Connection String",
    value=st.session_state.connection_string or "",
    type="password",
)

# Connect only when a non-empty value differs from the stored one; the new
# value is stored only after the connection attempt succeeds.
if (
    connection_string
    and connection_string != st.session_state.connection_string
    and connect_to_db(connection_string)
):
    st.session_state.connection_string = connection_string
    st.success("Successfully connected to database!")
|
|
|
|
|
# Step 2: list the tables of the connected database and show the schema and
# a few rows of the chosen one.
# NOTE(review): the nesting below was reconstructed from a source whose
# indentation was lost - confirm against the original layout.
if st.session_state.connection_string:
    st.header("2. Table Selection")
    inspector = inspect(st.session_state.engine)
    tables = inspector.get_table_names()

    # Pre-select 'lsq_leads' when it exists, otherwise the first table.
    try:
        default_index = tables.index('lsq_leads')
    except ValueError:
        default_index = 0
    selected_table = st.selectbox("Select a table", tables, index=default_index)

    # Containers are created up front so the schema is laid out above the data.
    schema_container = st.container()
    data_container = st.container()

    if selected_table:
        # Reload the schema and sample rows only when the selection changed.
        if selected_table != st.session_state.selected_table:
            st.session_state.selected_table = selected_table
            st.session_state.table_schema = get_table_schema(selected_table)
            st.session_state.sample_data = get_sample_data(selected_table)

        with schema_container:
            if st.session_state.table_schema:
                st.subheader("Table Schema")
                st.empty()
                st.json(st.session_state.table_schema)

        with data_container:
            if st.session_state.sample_data is not None:
                st.subheader("Sample Data (Last 3 rows)")
                st.empty()
                st.dataframe(
                    st.session_state.sample_data,
                    use_container_width=True,
                    hide_index=True,
                )
|
|
|
|
|
# Step 3: take a plain-English request, generate SQL for it and run it.
# NOTE(review): the nesting below was reconstructed from a source whose
# indentation was lost - confirm against the original layout.
if st.session_state.selected_table:
    st.header("3. Query Input")
    user_query = st.text_area("Enter your query in plain English")

    # Act only on a button press that comes with a non-empty request.
    if st.button("Generate and Execute Query") and user_query:
        sql_query = generate_sql_query(user_query)

        st.subheader("Generated SQL Query")
        st.code(sql_query, language="sql")

        results = execute_query(sql_query)
        if results is not None:
            st.subheader("Query Results")
            st.dataframe(results)
|
|
|
|
|
|