from datasets import Dataset
from sql_metadata import Parser


def format_deepseek_chat(example, tokenizer, input_prompt):
    """Tokenize one example, masking the prompt so the loss is computed
    only on the completion (the list of tables)."""
    prompt = f"{input_prompt}{example['natural_query']}\n"
    completion = f"Tables:\n{example['tables']}"

    full_text = prompt + completion
    tokenized = tokenizer(
        full_text,
        truncation=True,
        padding="max_length",
        max_length=3156,
    )

    # Token length of the prompt alone; pass the same max_length so the
    # measurement matches the truncation applied to full_text above.
    prompt_len = len(
        tokenizer(prompt, truncation=True, max_length=3156)["input_ids"]
    )
    labels = tokenized["input_ids"][:]
    # Ignore prompt tokens in the loss.
    labels[:prompt_len] = [-100] * prompt_len
    # Ignore padding positions too, using the attention mask to find them.
    labels = [
        label if mask == 1 else -100
        for label, mask in zip(labels, tokenized["attention_mask"])
    ]
    tokenized["labels"] = labels

    return tokenized
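
# Schematic illustration of the masking above (token ids are made up,
# not from the source): p* are prompt tokens, c* are completion tokens.
#   input_ids: [bos,  p1,   p2,   p3,   c1, c2, pad,  pad]
#   labels:    [-100, -100, -100, -100, c1, c2, -100, -100]
# Only the completion tokens contribute to the training loss.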


def get_tokenized_dataset(nba_df, tokenizer, input_prompt):
    """Build tokenized train/validation splits where each target is the
    list of tables referenced by the gold SQL query."""
    natural_query_list = nba_df["natural_query"].tolist()
    sql_query_list = nba_df["sql_query"].tolist()
    # Parse each gold SQL query to extract the table names it uses.
    tables = [Parser(sql_query).tables for sql_query in sql_query_list]

    dataset_dict = {
        "natural_query": natural_query_list,
        "tables": tables,
    }
    dataset = Dataset.from_dict(dataset_dict)

    tokenized_dataset = dataset.map(
        lambda x: format_deepseek_chat(x, tokenizer, input_prompt),
        remove_columns=["natural_query", "tables"],
    )

    # 90/10 split in row order; shuffle the dataframe upstream if it is
    # sorted, so the validation set is not biased.
    split = int(0.9 * len(tokenized_dataset))
    train_dataset = tokenized_dataset.select(range(split))
    val_dataset = tokenized_dataset.select(range(split, len(tokenized_dataset)))
    return train_dataset, val_dataset
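

# --- Usage sketch (not part of the original pipeline) ---
# A minimal example of wiring the functions above together. The checkpoint
# name, CSV path, and prompt string are assumptions for illustration, not
# values from the source.
if __name__ == "__main__":
    import pandas as pd
    from transformers import AutoTokenizer

    # Hypothetical DeepSeek checkpoint; substitute the model actually used.
    tokenizer = AutoTokenizer.from_pretrained(
        "deepseek-ai/deepseek-coder-1.3b-instruct"
    )
    # Some checkpoints ship without a pad token; reuse EOS so that
    # padding="max_length" in format_deepseek_chat works.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Hypothetical CSV with `natural_query` and `sql_query` columns.
    nba_df = pd.read_csv("nba_text_to_sql.csv")
    input_prompt = "List the tables needed to answer this question: "
    train_dataset, val_dataset = get_tokenized_dataset(
        nba_df, tokenizer, input_prompt
    )
    print(len(train_dataset), len(val_dataset))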