|
import gradio as gr |
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
import torch |
|
import re |
|
import sqlparse |
|
|
|
|
|
# Single source of truth for the checkpoint, so the model and the tokenizer
# are always loaded from the same repository.
MODEL_ID = "onkolahmet/Qwen2-0.5B-Instruct-SQL-generator"

# Device that tokenized prompts are moved to before generation.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Loaded once at import time and shared by every request.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype="auto",
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def generate_sql(question, context=None):
    """Generate SQL text for a natural-language question.

    Args:
        question: The natural-language question to translate.
        context: Optional table/schema description. It is added to the prompt
            only when it contains non-whitespace text.

    Returns:
        The text generated after the prompt, decoded without special tokens
        and stripped of surrounding whitespace. Sampling is enabled, so
        repeated calls with the same input may return different text.
    """
    prompt = "Translate natural language questions to SQL queries.\n\n"

    if context and context.strip():
        prompt += f"Table Context:\n{context}\n\n"

    prompt += f"Q: {question}\nSQL:"

    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    # Forward the whole encoding (input ids *and* attention mask) instead of
    # only the ids, so generate() does not have to guess the mask.
    outputs = model.generate(
        **inputs,
        max_new_tokens=128,
        do_sample=True,
        eos_token_id=tokenizer.eos_token_id,
    )

    # generate() returns prompt + completion; keep only the completion.
    prompt_length = inputs.input_ids.shape[-1]
    sql_query = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
    return sql_query.strip()
|
|
|
# Tokens that matter when looking for statement boundaries: quoted literals
# (consumed as a unit so their contents are never read as SQL), parentheses,
# and bare words.
_SQL_TOKEN_RE = re.compile(
    r"'(?:[^'\n]|'')*'"
    r'|"(?:[^"\n]|"")*"'
    r"|[()]"
    r"|\w+"
)

# A SELECT that directly follows one of these words continues the current
# statement (set operation) rather than starting a new one.
_SELECT_CONTINUATIONS = frozenset(
    ("UNION", "INTERSECT", "EXCEPT", "MINUS", "ALL", "DISTINCT")
)


def _top_level_select_starts(text):
    """Return offsets of SELECT keywords that begin a new top-level statement."""
    starts = []
    depth = 0
    previous = ""
    for match in _SQL_TOKEN_RE.finditer(text):
        token = match.group()
        if token == "(":
            depth += 1
        elif token == ")":
            # Clamp so a stray ")" cannot hide later top-level statements.
            depth = max(depth - 1, 0)
        elif token[0] in "'\"":
            pass  # quoted literal: never a keyword
        elif (
            token.upper() == "SELECT"
            and depth == 0
            and previous not in _SELECT_CONTINUATIONS
        ):
            starts.append(match.start())
        previous = token.upper()
    return starts


def _normalize_for_comparison(query):
    """Return a canonical form of *query* used only to detect duplicates."""
    try:
        return sqlparse.format(
            query + ('' if query.strip().endswith(';') else ';'),
            keyword_case='lower',
            identifier_case='lower',
            strip_comments=True,
            reindent=True,
        )
    except Exception:
        # Best effort: fall back to a plain whitespace/case normalization.
        return re.sub(r'\s+', ' ', query.lower().strip())


def clean_sql_output(sql_text):
    """Reduce raw model output to a single formatted SQL statement.

    Steps:
        1. Remove ``--`` / ``/* */`` comments and Markdown code fences.
        2. Split the text into candidate statements: on ``;`` when present,
           otherwise at each SELECT that starts a new top-level statement
           (SELECTs inside parentheses or after UNION/INTERSECT/EXCEPT are
           part of the current statement and are not split).
        3. Drop duplicate candidates and keep the longest one, preferring
           candidates that contain a SELECT.
        4. Ensure a trailing ``;``, collapse whitespace and format the result.

    Args:
        sql_text: Raw text produced by the model. Falsy input yields ``""``.

    Returns:
        The formatted statement, the unformatted statement if formatting
        fails, or ``""`` when nothing remains after cleaning.
    """
    if not sql_text:
        return ""

    # NOTE: comment and semicolon handling is purely textual, so "--" or ";"
    # inside a string literal is still treated as a comment/separator.
    sql_text = re.sub(r'--.*?$', '', sql_text, flags=re.MULTILINE)
    sql_text = re.sub(r'/\*.*?\*/', '', sql_text, flags=re.DOTALL)
    sql_text = re.sub(r'```sql|```', '', sql_text)

    if ';' in sql_text:
        queries = [part.strip() for part in sql_text.split(';')]
    else:
        flat = re.sub(r'\s+', ' ', sql_text)
        starts = _top_level_select_starts(flat)
        if len(starts) > 1:
            # Several statements were emitted back to back without ";".
            bounds = starts + [len(flat)]
            queries = [flat[a:b].strip() for a, b in zip(bounds, bounds[1:])]
        else:
            queries = [sql_text]

    queries = [q for q in queries if q.strip()]
    if not queries:
        return ""

    if len(queries) > 1:
        seen = set()
        unique_queries = []
        for query in queries:
            key = _normalize_for_comparison(query)
            if key not in seen:
                seen.add(key)
                unique_queries.append(query)

        select_queries = [
            q for q in unique_queries if re.search(r'\bSELECT\b', q, re.IGNORECASE)
        ]
        # Longest candidate wins; max() keeps the earliest one on ties.
        best_query = max(select_queries or unique_queries, key=len)
    else:
        best_query = queries[0]

    best_query = best_query.strip()
    if not best_query.endswith(';'):
        best_query += ';'
    best_query = re.sub(r'\s+', ' ', best_query)

    try:
        return sqlparse.format(
            best_query,
            keyword_case='upper',
            identifier_case='lower',
            reindent=True,
            indent_width=2,
        )
    except Exception:
        # Formatting is cosmetic; return the unformatted statement instead.
        return best_query
|
|
|
def process_input(question, table_context):
    """Run one UI request: generate SQL for *question* and clean it up.

    Args:
        question: Natural-language question from the UI. ``None``, empty or
            whitespace-only values are rejected with a prompt message.
        table_context: Optional schema text forwarded to ``generate_sql``.

    Returns:
        The cleaned SQL statement, or a user-facing message when the question
        is missing or no SQL could be extracted from the model output.
    """
    # Guard against None as well as blank text so .strip() cannot raise.
    if not question or not question.strip():
        return "Please enter a question."

    raw_sql = generate_sql(question, table_context)
    cleaned_sql = clean_sql_output(raw_sql)

    if not cleaned_sql:
        return "Sorry, I couldn't generate a valid SQL query. Please try rephrasing your question."

    return cleaned_sql
|
|
|
|
|
# Sample schemas used as "Table Context" values. The UI examples refer to
# these entries by index, so keep the order stable.
example_contexts = [
    # [0] customers
    """
CREATE TABLE customers (
    id INT PRIMARY KEY,
    name VARCHAR(100),
    email VARCHAR(100),
    order_date DATE
);
""",
    # [1] products
    """
CREATE TABLE products (
    id INT PRIMARY KEY,
    name VARCHAR(100),
    category VARCHAR(50),
    price DECIMAL(10,2),
    stock_quantity INT
);
""",
    # [2] employees + departments
    """
CREATE TABLE employees (
    id INT PRIMARY KEY,
    name VARCHAR(100),
    department VARCHAR(50),
    salary DECIMAL(10,2),
    hire_date DATE
);

CREATE TABLE departments (
    id INT PRIMARY KEY,
    name VARCHAR(50),
    manager_id INT,
    budget DECIMAL(15,2)
);
"""
]

# Sample natural-language questions.
# NOTE(review): nothing in the visible code reads this list (the UI examples
# define their own questions inline) - confirm whether it is still needed.
example_questions = [
    "Get the names and emails of customers who placed an order in the last 30 days.",
    "Find all products with less than 10 items in stock.",
    "List all employees in the Sales department with a salary greater than 50000.",
    "What is the total budget for departments with more than 5 employees?",
    "Count how many products are in each category where the price is greater than 100."
]
|
|
|
|
|
# Gradio UI: question + optional schema on the left, generated SQL on the right.
with gr.Blocks(title="Text to SQL Converter") as demo:
    gr.Markdown("# Text to SQL Query Converter")
    gr.Markdown("Enter your question and optional table context to generate an SQL query.")

    with gr.Row():
        with gr.Column():
            question_input = gr.Textbox(
                label="Your Question",
                placeholder="e.g., Find all products with price less than $50",
                lines=2,
            )
            table_context = gr.Textbox(
                label="Table Context (Optional)",
                placeholder="Enter your database schema or table definitions here...",
                lines=10,
            )
            submit_btn = gr.Button("Generate SQL Query")

        with gr.Column():
            sql_output = gr.Code(
                label="Generated SQL Query",
                language="sql",
                lines=12,
            )

    gr.Markdown("### Try some examples")

    # Each example fills both inputs: [question, schema from example_contexts].
    example_selector = gr.Examples(
        examples=[
            ["List all products in the 'Electronics' category with price less than $500", example_contexts[1]],
            ["Find the total number of employees in each department", example_contexts[2]],
            ["Get customers who placed orders in the last 7 days", example_contexts[0]],
            ["Count the number of products in each category", example_contexts[1]],
            ["Find the average salary by department", example_contexts[2]],
        ],
        inputs=[question_input, table_context],
    )

    # Clicking the button and pressing Enter in the question box run the same
    # pipeline, so the wiring is defined once and reused.
    generate_event = dict(
        fn=process_input,
        inputs=[question_input, table_context],
        outputs=sql_output,
    )
    submit_btn.click(**generate_event)
    question_input.submit(**generate_event)

    gr.Markdown("""
### About
This app uses a fine-tuned language model to convert natural language questions into SQL queries.

- **Model**: [onkolahmet/Qwen2-0.5B-Instruct-SQL-generator](https://huggingface.co/onkolahmet/Qwen2-0.5B-Instruct-SQL-generator)
- **How to use**:
  1. Enter your question in natural language
  2. If you have specific table schemas, add them in the Table Context field
  3. Click "Generate SQL Query" or press Enter

Note: The model works best when table context is provided, but can generate generic SQL queries without it.
""")

# Start the server only when executed as a script, so importing this module
# (e.g. from a test or a reloader) does not launch it as a side effect.
if __name__ == "__main__":
    demo.launch()