HusnaManakkot commited on
Commit
8c7fb17
Β·
verified Β·
1 Parent(s): d66d4d2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -8
app.py CHANGED
@@ -1,22 +1,19 @@
1
  import gradio as gr
2
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
3
  from datasets import load_dataset
4
 
5
  # Load tokenizer and model
6
  tokenizer = AutoTokenizer.from_pretrained("hrshtsharma2012/NL2SQL-Picard-final")
7
  model = AutoModelForSeq2SeqLM.from_pretrained("hrshtsharma2012/NL2SQL-Picard-final")
8
 
9
- # Initialize the pipeline
10
- nl2sql_pipeline = pipeline("text2text-generation", model=model, tokenizer=tokenizer)
11
-
12
  # Load a part of the Spider dataset
13
  spider_dataset = load_dataset("spider", split='train[:5]')
14
 
15
  def generate_sql(query):
16
- results = nl2sql_pipeline(query)
17
- sql_query = results[0]['generated_text']
18
- # Post-process the output to ensure it's a valid SQL query
19
- sql_query = sql_query.replace('<pad>', '').replace('</s>', '').strip()
20
  return sql_query
21
 
22
  # Use examples from the Spider dataset
 
1
  import gradio as gr
2
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
3
  from datasets import load_dataset
4
 
5
  # Load tokenizer and model
6
  tokenizer = AutoTokenizer.from_pretrained("hrshtsharma2012/NL2SQL-Picard-final")
7
  model = AutoModelForSeq2SeqLM.from_pretrained("hrshtsharma2012/NL2SQL-Picard-final")
8
 
 
 
 
9
  # Load a part of the Spider dataset
10
  spider_dataset = load_dataset("spider", split='train[:5]')
11
 
12
  def generate_sql(query):
13
+ input_text = "translate English to SQL: " + query
14
+ inputs = tokenizer(input_text, return_tensors="pt", padding=True)
15
+ outputs = model.generate(**inputs, max_length=512)
16
+ sql_query = tokenizer.decode(outputs[0], skip_special_tokens=True)
17
  return sql_query
18
 
19
  # Use examples from the Spider dataset