HusnaManakkot commited on
Commit
72e6803
Β·
verified Β·
1 Parent(s): a588039

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -21
app.py CHANGED
@@ -1,37 +1,37 @@
1
  import gradio as gr
2
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
 
3
 
4
- # Load the tokenizer and model
5
- tokenizer = AutoTokenizer.from_pretrained("Salesforce/codet5-base-multi-summarization-sql-en")
6
- model = AutoModelForSeq2SeqLM.from_pretrained("Salesforce/codet5-base-multi-summarization-sql-en")
7
 
8
- def generate_sql(natural_language_query):
9
- # Tokenize the input query
10
- input_ids = tokenizer(natural_language_query, return_tensors="pt").input_ids
11
 
12
- # Generate the SQL query
13
- output_ids = model.generate(input_ids, max_length=512)[0]
14
 
15
- # Decode the generated SQL query
16
- sql_query = tokenizer.decode(output_ids, skip_special_tokens=True)
 
 
 
17
  return sql_query
18
 
19
- # Example questions for the interface
20
- example_questions = [
21
- "What is the average salary of employees?",
22
- "List the names of employees who work in the IT department.",
23
- "Count the number of employees who joined after 2015."
24
- ]
25
 
26
- # Create the Gradio interface
27
  interface = gr.Interface(
28
  fn=generate_sql,
29
  inputs=gr.Textbox(lines=2, placeholder="Enter your natural language query here..."),
30
  outputs="text",
31
  examples=example_questions,
32
- title="NL to SQL with CodeT5",
33
  description="This model converts natural language queries into SQL using the WikiSQL dataset. Try one of the example questions or enter your own!"
34
  )
35
 
36
- # Launch the interface
37
- interface.launch()
 
 
1
  import gradio as gr
2
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
3
+ from datasets import load_dataset
4
 
5
+ # Load tokenizer and model
6
+ tokenizer = AutoTokenizer.from_pretrained("hrshtsharma2012/NL2SQL-Picard-final")
7
+ model = AutoModelForSeq2SeqLM.from_pretrained("hrshtsharma2012/NL2SQL-Picard-final")
8
 
9
+ # Initialize the pipeline
10
+ nl2sql_pipeline = pipeline("text2text-generation", model=model, tokenizer=tokenizer)
 
11
 
12
+ # Load a part of the WikiSQL dataset
13
+ wikisql_dataset = load_dataset("wikisql", split='train[:5]')
14
 
15
+ def generate_sql(query):
16
+ results = nl2sql_pipeline(query)
17
+ sql_query = results[0]['generated_text']
18
+ # Post-process the output to ensure it's a valid SQL query
19
+ sql_query = sql_query.replace('<pad>', '').replace('</s>', '').strip()
20
  return sql_query
21
 
22
+ # Use examples from the WikiSQL dataset
23
+ example_questions = [(question['question'],) for question in wikisql_dataset]
 
 
 
 
24
 
25
+ # Create a Gradio interface
26
  interface = gr.Interface(
27
  fn=generate_sql,
28
  inputs=gr.Textbox(lines=2, placeholder="Enter your natural language query here..."),
29
  outputs="text",
30
  examples=example_questions,
31
+ title="NL to SQL with Picard",
32
  description="This model converts natural language queries into SQL using the WikiSQL dataset. Try one of the example questions or enter your own!"
33
  )
34
 
35
+ # Launch the app
36
+ if __name__ == "__main__":
37
+ interface.launch()