HusnaManakkot committed on
Commit
7a62d53
·
verified ·
1 Parent(s): b8ee71c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -16
app.py CHANGED
@@ -3,33 +3,45 @@ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
  from datasets import load_dataset
4
 
5
  # Load the Spider dataset
6
- spider_dataset = load_dataset("spider", split='train[:1000]')
7
 
8
- # Load tokenizer and model
9
- tokenizer = AutoTokenizer.from_pretrained("mrm8488/t5-base-finetuned-wikiSQL")
10
- model = AutoModelForSeq2SeqLM.from_pretrained("mrm8488/t5-base-finetuned-wikiSQL")
 
 
 
 
 
11
 
12
- def generate_sql_from_dataset(index):
13
- # Ensure the index is within the range of the dataset
14
- index = int(index) # Convert to integer in case it's passed as a string
15
- if index < 0 or index >= len(spider_dataset):
16
- return "Invalid index. Please enter a number between 0 and {}.".format(len(spider_dataset) - 1), ""
17
 
18
- # Get the natural language query from the dataset
19
- query = spider_dataset[index]['question']
20
  input_text = "translate English to SQL: " + query
21
  inputs = tokenizer(input_text, return_tensors="pt", padding=True)
22
  outputs = model.generate(**inputs, max_length=512)
23
  sql_query = tokenizer.decode(outputs[0], skip_special_tokens=True)
24
- return query, sql_query
 
 
 
 
 
 
 
 
 
25
 
26
  # Create a Gradio interface
27
  interface = gr.Interface(
28
- fn=generate_sql_from_dataset,
29
- inputs=gr.Number(label="Dataset Index (0-4)"),
30
- outputs=[gr.Textbox(label="Natural Language Query"), gr.Textbox(label="Generated SQL Query")],
31
  title="NL to SQL with T5 using Spider Dataset",
32
- description="This model converts natural language queries from the Spider dataset into SQL. Enter the index of the dataset entry (0-4)!"
33
  )
34
 
35
  # Launch the app
 
3
  from datasets import load_dataset
4
 
5
  # Load the Spider dataset
6
+ spider_dataset = load_dataset("spider", split='train') # Load a subset of the dataset
7
 
8
+ # Extract schema information from the Spider dataset
9
+ table_names = set()
10
+ column_names = set()
11
+ for item in spider_dataset:
12
+ for table in item['db']['table_names_original']:
13
+ table_names.add(table)
14
+ for column in item['db']['column_names_original']:
15
+ column_names.add(column[1])
16
 
17
+ # Load tokenizer and model
18
+ tokenizer = AutoTokenizer.from_pretrained("mrm8488/t5-base-finetuned-wikiSQL") # Update this to a model fine-tuned on Spider if available
19
+ model = AutoModelForSeq2SeqLM.from_pretrained("mrm8488/t5-base-finetuned-wikiSQL") # Update this to a model fine-tuned on Spider if available
 
 
20
 
21
+ def generate_sql_from_user_input(query):
22
+ # Generate SQL for the user's query
23
  input_text = "translate English to SQL: " + query
24
  inputs = tokenizer(input_text, return_tensors="pt", padding=True)
25
  outputs = model.generate(**inputs, max_length=512)
26
  sql_query = tokenizer.decode(outputs[0], skip_special_tokens=True)
27
+
28
+ # Post-process the SQL query to match the dataset's schema
29
+ for table_name in table_names:
30
+ if "TABLE" in sql_query:
31
+ sql_query = sql_query.replace("TABLE", table_name)
32
+ break # Assuming only one table is referenced in the query
33
+ for column_name in column_names:
34
+ if "COLUMN" in sql_query:
35
+ sql_query = sql_query.replace("COLUMN", column_name, 1)
36
+ return sql_query
37
 
38
  # Create a Gradio interface
39
  interface = gr.Interface(
40
+ fn=generate_sql_from_user_input,
41
+ inputs=gr.Textbox(label="Enter your natural language query"),
42
+ outputs=gr.Textbox(label="Generated SQL Query"),
43
  title="NL to SQL with T5 using Spider Dataset",
44
+ description="This model generates an SQL query for your natural language input based on the Spider dataset."
45
  )
46
 
47
  # Launch the app