HusnaManakkot commited on
Commit
f525ef3
Β·
verified Β·
1 Parent(s): 3daf9f7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -11
app.py CHANGED
@@ -1,29 +1,35 @@
1
  import gradio as gr
2
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
 
3
 
4
  # Load tokenizer and model
5
  tokenizer = AutoTokenizer.from_pretrained("hrshtsharma2012/NL2SQL-Picard-final")
6
  model = AutoModelForSeq2SeqLM.from_pretrained("hrshtsharma2012/NL2SQL-Picard-final")
7
 
 
 
 
 
 
 
8
  def generate_sql(query):
9
- input_text = "translate English to SQL: " + query
10
- inputs = tokenizer(input_text, return_tensors="pt", padding=True)
11
- outputs = model.generate(**inputs, max_length=512)
12
- sql_query = tokenizer.decode(outputs[0], skip_special_tokens=True)
13
-
14
- # Check if the output is the same as the input
15
- if sql_query.strip().lower() == query.strip().lower():
16
- return "The model did not generate a SQL query. Please try a different input or use a different model."
17
-
18
  return sql_query
19
 
 
 
 
20
  # Create a Gradio interface
21
  interface = gr.Interface(
22
  fn=generate_sql,
23
  inputs=gr.Textbox(lines=2, placeholder="Enter your natural language query here..."),
24
  outputs="text",
 
25
  title="NL to SQL with Picard",
26
- description="This model converts natural language queries into SQL. Enter your query!"
27
  )
28
 
29
  # Launch the app
 
1
  import gradio as gr
2
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
3
+ from datasets import load_dataset
4
 
5
  # Load tokenizer and model
6
  tokenizer = AutoTokenizer.from_pretrained("hrshtsharma2012/NL2SQL-Picard-final")
7
  model = AutoModelForSeq2SeqLM.from_pretrained("hrshtsharma2012/NL2SQL-Picard-final")
8
 
9
+ # Initialize the pipeline
10
+ nl2sql_pipeline = pipeline("text2text-generation", model=model, tokenizer=tokenizer)
11
+
12
+ # Load a part of the Spider dataset
13
+ spider_dataset = load_dataset("spider", split='train[:5]')
14
+
15
  def generate_sql(query):
16
+ results = nl2sql_pipeline(query)
17
+ sql_query = results[0]['generated_text']
18
+ # Post-process the output to ensure it's a valid SQL query
19
+ sql_query = sql_query.replace('<pad>', '').replace('</s>', '').strip()
 
 
 
 
 
20
  return sql_query
21
 
22
+ # Use examples from the Spider dataset
23
+ example_questions = [(question['question'],) for question in spider_dataset]
24
+
25
  # Create a Gradio interface
26
  interface = gr.Interface(
27
  fn=generate_sql,
28
  inputs=gr.Textbox(lines=2, placeholder="Enter your natural language query here..."),
29
  outputs="text",
30
+ examples=example_questions,
31
  title="NL to SQL with Picard",
32
+ description="This model converts natural language queries into SQL using the Spider dataset. Try one of the example questions or enter your own!"
33
  )
34
 
35
  # Launch the app