File size: 10,494 Bytes
850c4bf 1acff14 850c4bf 140471a 850c4bf 140471a e7e63fd 140471a e7e63fd 850c4bf 1acff14 140471a 1acff14 850c4bf 1acff14 850c4bf e7e63fd 850c4bf 1acff14 850c4bf e7e63fd 850c4bf e7e63fd 850c4bf e7e63fd 1acff14 850c4bf 1acff14 850c4bf e7e63fd 850c4bf e7e63fd 850c4bf 1acff14 850c4bf 140471a 850c4bf 1acff14 850c4bf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 |
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import re
import sqlparse
# Load the fine-tuned text-to-SQL model and its tokenizer once, at import time;
# every request handled by this module reuses these module-level objects.
MODEL_ID = "onkolahmet/Qwen2-0.5B-Instruct-SQL-generator"

# Device name for tokenized inputs. The model weights themselves are placed by
# device_map="auto" below.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype="auto",
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# # Few-shot examples to include in each prompt
# examples = [
# {
# "question": "Get the names and emails of customers who placed an order in the last 30 days.",
# "sql": "SELECT name, email FROM customers WHERE order_date >= DATE_SUB(CURDATE(), INTERVAL 30 DAY);"
# },
# {
# "question": "Find all employees with a salary greater than 50000.",
# "sql": "SELECT * FROM employees WHERE salary > 50000;"
# },
# {
# "question": "List all product names and their categories where the price is below 50.",
# "sql": "SELECT name, category FROM products WHERE price < 50;"
# },
# {
# "question": "How many users registered in the year 2022?",
# "sql": "SELECT COUNT(*) FROM users WHERE YEAR(registration_date) = 2022;"
# }
# ]
def generate_sql(question, context=None, examples=None, max_new_tokens=128, do_sample=True):
    """Generate a raw SQL completion for a natural-language question.

    Builds a plain-text "Q: ... / SQL:" prompt, optionally prefixed with a
    table schema and few-shot examples, and decodes a continuation from the
    module-level ``model``.

    Args:
        question: The natural-language question to translate.
        context: Optional schema / table definitions. Ignored when empty or
            whitespace-only.
        examples: Optional iterable of ``{"question": ..., "sql": ...}`` dicts
            rendered as few-shot examples before the question. Defaults to
            none (zero-shot).
        max_new_tokens: Upper bound on the number of generated tokens.
        do_sample: ``True`` (default) samples; ``False`` decodes greedily.

    Returns:
        The generated text that follows the prompt, cut before any invented
        follow-up "Q:" turn and stripped of surrounding whitespace. It may
        still contain several statements or stray prose, so callers should
        post-process it.
    """
    parts = ["Translate natural language questions to SQL queries.\n\n"]
    if context and context.strip():
        parts.append(f"Table Context:\n{context}\n\n")
    for ex in examples or ():
        parts.append(f"Q: {ex['question']}\nSQL: {ex['sql']}\n\n")
    parts.append(f"Q: {question}\nSQL:")
    prompt = "".join(parts)

    # Send inputs to wherever device_map="auto" actually put the model.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # Fall back to EOS as the padding id when the tokenizer defines none.
    pad_id = tokenizer.pad_token_id
    if pad_id is None:
        pad_id = tokenizer.eos_token_id

    outputs = model.generate(
        input_ids=inputs.input_ids,
        # Pass the mask explicitly instead of letting generate() infer it.
        attention_mask=inputs.attention_mask,
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=pad_id,
    )

    # Decode only the newly generated tokens, not the echoed prompt.
    prompt_len = inputs.input_ids.shape[-1]
    completion = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)

    # The prompt is a Q/SQL transcript, so the model may continue by inventing
    # the next "Q:" turn; keep only the answer to the current question.
    completion = completion.split("\nQ:", 1)[0]
    return completion.strip()
# A complete single-quoted string, double-quoted identifier or backtick-quoted
# identifier. Doubled quotes ('' / "") inside a literal are treated as escapes.
_SQL_QUOTED_RE = re.compile(r"'(?:[^']|'')*'|\"(?:[^\"]|\"\")*\"|`[^`]*`")

# The SELECT keyword as a whole word, followed by whitespace.
_SELECT_RE = re.compile(r"\bSELECT\s+", re.IGNORECASE)

# When one of these words directly precedes a SELECT, that SELECT continues the
# current statement (set operation, CTAS / view body) rather than starting a new one.
_SELECT_CONTINUATION_WORDS = frozenset(
    {"UNION", "ALL", "DISTINCT", "INTERSECT", "EXCEPT", "MINUS", "AS"}
)


def _mask_quoted(text):
    """Return *text* with each quoted literal replaced by '_' characters of equal length.

    Offsets in the result line up with *text*. Delimiters found in the mask can
    therefore be used to slice the original without cutting through a literal.
    """
    return _SQL_QUOTED_RE.sub(lambda m: "_" * len(m.group()), text)


def _collapse_whitespace(text):
    """Collapse whitespace runs to a single space, leaving quoted literals untouched."""
    pieces = []
    pos = 0
    for m in _SQL_QUOTED_RE.finditer(text):
        pieces.append(re.sub(r"\s+", " ", text[pos:m.start()]))
        pieces.append(m.group())
        pos = m.end()
    pieces.append(re.sub(r"\s+", " ", text[pos:]))
    return "".join(pieces)


def _top_level_select_starts(masked):
    """Offsets in *masked* of SELECT keywords that begin a new top-level statement."""
    candidates = {m.start() for m in _SELECT_RE.finditer(masked)}
    starts = []
    depth = 0
    for i, ch in enumerate(masked):
        if i in candidates and depth == 0:
            preceding = masked[:i].split()
            if not preceding or preceding[-1].upper() not in _SELECT_CONTINUATION_WORDS:
                starts.append(i)
        # Track parenthesis depth so SELECTs inside subqueries are not split points.
        if ch == "(":
            depth += 1
        elif ch == ")" and depth:
            depth -= 1
    return starts


def _split_queries(sql_text):
    """Split cleaned model output into candidate statements (blank entries possible)."""
    masked = _mask_quoted(sql_text)
    if ";" in masked:
        # Cut at every semicolon that is not inside a quoted literal.
        queries = []
        start = 0
        for i, ch in enumerate(masked):
            if ch == ";":
                queries.append(sql_text[start:i].strip())
                start = i + 1
        queries.append(sql_text[start:].strip())
        return queries

    # No semicolons: look for several top-level SELECT statements run together.
    collapsed = _collapse_whitespace(sql_text)
    starts = _top_level_select_starts(_mask_quoted(collapsed))
    if len(starts) <= 1:
        return [sql_text]
    ends = starts[1:] + [len(collapsed)]
    return [collapsed[s:e].strip() for s, e in zip(starts, ends)]


def _dedupe_key(query):
    """Normalised form of *query*, used only to detect duplicate statements."""
    try:
        return sqlparse.format(
            query + ("" if query.strip().endswith(";") else ";"),
            keyword_case="lower",
            identifier_case="lower",
            strip_comments=True,
            reindent=True,
        )
    except Exception:
        # sqlparse failed: fall back to a crude lowercase/whitespace normalisation.
        return re.sub(r"\s+", " ", query.lower().strip())


def clean_sql_output(sql_text):
    """Reduce raw model output to a single, formatted SQL statement.

    Steps:
    1. Strip ``--`` and ``/* */`` comments and Markdown code fences
       (```` ```sql ```` in any letter case, or bare ```` ``` ````).
    2. Split into candidate statements. Split on semicolons outside quoted
       literals, or, if there are none, at SELECT keywords that start a new
       top-level statement. SELECTs inside parentheses or right after
       UNION/ALL/DISTINCT/INTERSECT/EXCEPT/MINUS/AS stay with their statement.
    3. Drop duplicates and keep the longest statement containing SELECT,
       otherwise the longest statement.
    4. Terminate with ``;``, collapse whitespace outside quoted literals, and
       pretty-print with sqlparse when that succeeds.

    NOTE(review): comment stripping in step 1 is not quote-aware, so ``--``
    inside a string literal is still treated as a comment.

    Args:
        sql_text: Raw text produced by the model. ``None`` or empty input
            yields ``""``.

    Returns:
        The formatted query, or ``""`` if nothing remains after cleaning.
    """
    if not sql_text:
        return ""

    sql_text = re.sub(r"--.*?$", "", sql_text, flags=re.MULTILINE)
    sql_text = re.sub(r"/\*.*?\*/", "", sql_text, flags=re.DOTALL)
    sql_text = re.sub(r"```(?:sql)?", "", sql_text, flags=re.IGNORECASE)

    queries = [q for q in _split_queries(sql_text) if q.strip()]
    if not queries:
        return ""

    if len(queries) > 1:
        seen = set()
        unique_queries = []
        for q in queries:
            key = _dedupe_key(q)
            if key not in seen:
                seen.add(key)
                unique_queries.append(q)
        # Prefer SELECT statements; among those, the longest is usually the
        # most complete. max() keeps the first on ties.
        select_queries = [q for q in unique_queries if _SELECT_RE.search(q)]
        best_query = max(select_queries or unique_queries, key=len)
    else:
        best_query = queries[0]

    best_query = best_query.strip()
    if not best_query.endswith(";"):
        best_query += ";"
    best_query = _collapse_whitespace(best_query)

    try:
        return sqlparse.format(
            best_query,
            keyword_case="upper",
            identifier_case="lower",
            reindent=True,
            indent_width=2,
        )
    except Exception:
        # Formatting is cosmetic; return the unformatted query if sqlparse fails.
        return best_query
def process_input(question, table_context):
    """Run the text-to-SQL pipeline for one UI request and return text for display.

    Args:
        question: Natural-language question from the UI. ``None``, empty or
            whitespace-only input short-circuits with a prompt to enter a
            question, and the model is not called.
        table_context: Optional schema text, forwarded unchanged to
            ``generate_sql``.

    Returns:
        The cleaned SQL query. Returns a human-readable message instead when
        the question is blank or no usable SQL could be extracted.
    """
    # Guard against None as well as blank strings before touching the model.
    if not question or not question.strip():
        return "Please enter a question."

    raw_sql = generate_sql(question, table_context)
    cleaned_sql = clean_sql_output(raw_sql)
    if not cleaned_sql:
        return "Sorry, I couldn't generate a valid SQL query. Please try rephrasing your question."
    return cleaned_sql
# Sample schemas that can be paired with questions in the examples panel.
_CUSTOMERS_SCHEMA = """
CREATE TABLE customers (
id INT PRIMARY KEY,
name VARCHAR(100),
email VARCHAR(100),
order_date DATE
);
"""

_PRODUCTS_SCHEMA = """
CREATE TABLE products (
id INT PRIMARY KEY,
name VARCHAR(100),
category VARCHAR(50),
price DECIMAL(10,2),
stock_quantity INT
);
"""

# Two related tables, so questions can involve a join between them.
_EMPLOYEES_AND_DEPARTMENTS_SCHEMA = """
CREATE TABLE employees (
id INT PRIMARY KEY,
name VARCHAR(100),
department VARCHAR(50),
salary DECIMAL(10,2),
hire_date DATE
);
CREATE TABLE departments (
id INT PRIMARY KEY,
name VARCHAR(50),
manager_id INT,
budget DECIMAL(15,2)
);
"""

# Indexed positionally by the UI: 0 = customers, 1 = products,
# 2 = employees + departments.
example_contexts = [
    _CUSTOMERS_SCHEMA,
    _PRODUCTS_SCHEMA,
    _EMPLOYEES_AND_DEPARTMENTS_SCHEMA,
]

# Sample natural-language questions.
# NOTE(review): the examples panel hard-codes its own questions; confirm this list is still needed.
example_questions = list((
    "Get the names and emails of customers who placed an order in the last 30 days.",
    "Find all products with less than 10 items in stock.",
    "List all employees in the Sales department with a salary greater than 50000.",
    "What is the total budget for departments with more than 5 employees?",
    "Count how many products are in each category where the price is greater than 100.",
))
# Create the Gradio interface: inputs on the left, generated SQL on the right.
with gr.Blocks(title="Text to SQL Converter") as demo:
    gr.Markdown("# Text to SQL Query Converter")
    gr.Markdown("Enter your question and optional table context to generate an SQL query.")

    with gr.Row():
        with gr.Column():
            question_input = gr.Textbox(
                label="Your Question",
                placeholder="e.g., Find all products with price less than $50",
                lines=2,
            )
            table_context = gr.Textbox(
                label="Table Context (Optional)",
                placeholder="Enter your database schema or table definitions here...",
                lines=10,
            )
            submit_btn = gr.Button("Generate SQL Query")
        with gr.Column():
            sql_output = gr.Code(
                label="Generated SQL Query",
                language="sql",
                lines=12,
            )

    # Examples section: each row fills both the question and the schema box.
    gr.Markdown("### Try some examples")
    example_selector = gr.Examples(
        examples=[
            ["List all products in the 'Electronics' category with price less than $500", example_contexts[1]],
            ["Find the total number of employees in each department", example_contexts[2]],
            ["Get customers who placed orders in the last 7 days", example_contexts[0]],
            ["Count the number of products in each category", example_contexts[1]],
            ["Find the average salary by department", example_contexts[2]],
        ],
        inputs=[question_input, table_context],
    )

    # The button click and the question box's submit (Enter) event run the same
    # handler with the same inputs and output.
    for trigger in (submit_btn.click, question_input.submit):
        trigger(
            fn=process_input,
            inputs=[question_input, table_context],
            outputs=sql_output,
        )

    # Add information about the model
    gr.Markdown("""
### About
This app uses a fine-tuned language model to convert natural language questions into SQL queries.
- **Model**: [onkolahmet/Qwen2-0.5B-Instruct-SQL-generator](https://huggingface.co/onkolahmet/Qwen2-0.5B-Instruct-SQL-generator)
- **How to use**:
1. Enter your question in natural language
2. If you have specific table schemas, add them in the Table Context field
3. Click "Generate SQL Query" or press Enter
Note: The model works best when table context is provided, but can generate generic SQL queries without it.
""")


# Start the server only when run as a script, so merely importing this module
# does not block in launch().
if __name__ == "__main__":
    demo.launch()