HusnaManakkot commited on
Commit
887c95b
·
verified ·
1 Parent(s): 6e41bb7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -10
app.py CHANGED
@@ -2,20 +2,21 @@ import gradio as gr
2
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
  from datasets import load_dataset
4
 
5
- # Load the WikiSQL dataset
6
- wikisql_dataset = load_dataset("wikisql", split='train') # Load a subset of the dataset
7
 
8
  # Extract schema information from the dataset
9
  table_names = set()
10
  column_names = set()
11
- for item in wikisql_dataset:
12
- table_names.add(item['table']['name'])
13
- for column in item['table']['header']:
14
- column_names.add(column)
 
15
 
16
  # Load tokenizer and model
17
- tokenizer = AutoTokenizer.from_pretrained("mrm8488/t5-base-finetuned-wikiSQL")
18
- model = AutoModelForSeq2SeqLM.from_pretrained("mrm8488/t5-base-finetuned-wikiSQL")
19
 
20
  def post_process_sql_query(sql_query):
21
  # Modify the SQL query to match the dataset's schema
@@ -45,8 +46,8 @@ interface = gr.Interface(
45
  fn=generate_sql_from_user_input,
46
  inputs=gr.Textbox(label="Enter your natural language query"),
47
  outputs=gr.Textbox(label="Generated SQL Query"),
48
- title="NL to SQL with T5 using WikiSQL Dataset",
49
- description="This model generates an SQL query for your natural language input based on the WikiSQL dataset."
50
  )
51
 
52
  # Launch the app
 
2
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
  from datasets import load_dataset
4
 
5
+ # Load the Spider dataset
6
+ spider_dataset = load_dataset("spider", split='train') # Load a subset of the dataset
7
 
8
  # Extract schema information from the dataset
9
  table_names = set()
10
  column_names = set()
11
+ for item in spider_dataset:
12
+ for table in item['db']['table_names_original']:
13
+ table_names.add(table)
14
+ for column in item['db']['column_names_original']:
15
+ column_names.add(column[1])
16
 
17
  # Load tokenizer and model
18
+ tokenizer = AutoTokenizer.from_pretrained("mrm8488/t5-base-finetuned-wikiSQL") # Update this to a model fine-tuned on Spider if available
19
+ model = AutoModelForSeq2SeqLM.from_pretrained("mrm8488/t5-base-finetuned-wikiSQL") # Update this to a model fine-tuned on Spider if available
20
 
21
  def post_process_sql_query(sql_query):
22
  # Modify the SQL query to match the dataset's schema
 
46
  fn=generate_sql_from_user_input,
47
  inputs=gr.Textbox(label="Enter your natural language query"),
48
  outputs=gr.Textbox(label="Generated SQL Query"),
49
+ title="NL to SQL with T5 using Spider Dataset",
50
+ description="This model generates an SQL query for your natural language input based on the Spider dataset."
51
  )
52
 
53
  # Launch the app