HusnaManakkot committed on
Commit 4b8f9d6 · verified · 1 Parent(s): 6c6d2f7

Update app.py

Files changed (1)
  1. app.py +16 -32
app.py CHANGED
@@ -3,52 +3,36 @@ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
  from datasets import load_dataset

  # Load the WikiSQL dataset
- wikisql_dataset = load_dataset("wikisql", split='train[:100]')  # Load a subset of the dataset
-
- # Extract schema information from the dataset
- table_names = set()
- column_names = set()
- for item in wikisql_dataset:
-     table_names.add(item['table']['name'])
-     for column in item['table']['header']:
-         column_names.add(column)
+ dataset = load_dataset("wikisql", split='train[:1000]')

  # Load tokenizer and model
- tokenizer = AutoTokenizer.from_pretrained("mrm8488/t5-base-finetuned-wikiSQL")
- model = AutoModelForSeq2SeqLM.from_pretrained("mrm8488/t5-base-finetuned-wikiSQL")
+ tokenizer = AutoTokenizer.from_pretrained("t5-base")
+ model = AutoModelForSeq2SeqLM.from_pretrained("t5-base")

- def post_process_sql_query(sql_query):
-     # Modify the SQL query to match the dataset's schema
-     # This is just an example and might need to be adapted based on the dataset and model output
-     for table_name in table_names:
-         if "TABLE" in sql_query:
-             sql_query = sql_query.replace("TABLE", table_name)
-             break  # Assuming only one table is referenced in the query
-     for column_name in column_names:
-         if "COLUMN" in sql_query:
-             sql_query = sql_query.replace("COLUMN", column_name, 1)
-     return sql_query
+ def preprocess_data(dataset):
+     # Tokenize the questions and SQL queries
+     tokenized_questions = tokenizer(dataset['question'], padding=True, truncation=True, return_tensors="pt")
+     tokenized_sql = tokenizer(dataset['sql']['human_readable'], padding=True, truncation=True, return_tensors="pt")
+     return tokenized_questions, tokenized_sql

- def generate_sql_from_user_input(query):
-     # Generate SQL for the user's query
+ def generate_sql(query):
+     # Preprocess the input query
      input_text = "translate English to SQL: " + query
      inputs = tokenizer(input_text, return_tensors="pt", padding=True)
+
+     # Generate SQL query using the model
      outputs = model.generate(**inputs, max_length=512)
      sql_query = tokenizer.decode(outputs[0], skip_special_tokens=True)
-
-     # Post-process the SQL query to match the dataset's schema
-     sql_query = post_process_sql_query(sql_query)
      return sql_query

  # Create a Gradio interface
  interface = gr.Interface(
-     fn=generate_sql_from_user_input,
+     fn=generate_sql,
      inputs=gr.Textbox(label="Enter your natural language query"),
      outputs=gr.Textbox(label="Generated SQL Query"),
-     title="NL to SQL with T5 using WikiSQL Dataset",
-     description="This model generates an SQL query for your natural language input based on the WikiSQL dataset."
+     title="Natural Language to SQL Prototype",
+     description="Enter a natural language query and get the corresponding SQL query."
  )

  # Launch the app
- if __name__ == "__main__":
-     interface.launch()
+ interface.launch()
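
A note on the new preprocess_data helper (this commentary and sketch are not part of the commit): as committed, preprocess_data is never called by the app, and with the datasets library, selecting a struct column such as dataset['sql'] typically returns one dict per example, so dataset['sql']['human_readable'] would likely need the strings pulled out per row. A minimal sketch under that assumption, reusing the names from the committed file:

def preprocess_data(dataset):
    # Collect the plain-text questions and the human-readable SQL target strings
    questions = dataset['question']
    sql_targets = [example['human_readable'] for example in dataset['sql']]
    # Tokenize both sides into padded, truncated PyTorch tensors
    tokenized_questions = tokenizer(questions, padding=True, truncation=True, return_tensors="pt")
    tokenized_sql = tokenizer(sql_targets, padding=True, truncation=True, return_tensors="pt")
    return tokenized_questions, tokenized_sql

# Hypothetical local check of the Gradio handler, outside the interface:
# print(generate_sql("How many singers are there?"))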