licesma commited on
Commit
bfad6ce
·
1 Parent(s): c1d6d12

Add support for colab

Browse files
Files changed (1) hide show
  1. test_pretrained.ipynb +52 -135
test_pretrained.ipynb CHANGED
@@ -7,6 +7,56 @@
7
  "# Run pre-trained DeepSeek Coder 1.3B Model on Chat-GPT 4o generated dataset"
8
  ]
9
  },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  {
11
  "cell_type": "markdown",
12
  "metadata": {},
@@ -33,10 +83,6 @@
33
  }
34
  ],
35
  "source": [
36
- "import pandas as pd \n",
37
- "import warnings\n",
38
- "warnings.filterwarnings(\"ignore\")\n",
39
- "\n",
40
  "# Load dataset and check length\n",
41
  "df = pd.read_csv(\"./train-data/sql_train.tsv\", sep='\\t')\n",
42
  "print(\"Total dataset examples: \" + str(len(df)))\n",
@@ -62,9 +108,6 @@
62
  "metadata": {},
63
  "outputs": [],
64
  "source": [
65
- "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
66
- "import torch\n",
67
- "\n",
68
  "# Set device to cuda if available, otherwise CPU\n",
69
  "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
70
  "\n",
@@ -74,22 +117,6 @@
74
  "model.generation_config.pad_token_id = tokenizer.pad_token_id"
75
  ]
76
  },
77
- {
78
- "cell_type": "markdown",
79
- "metadata": {},
80
- "source": [
81
- "## Create prompt to setup the model for better performance"
82
- ]
83
- },
84
- {
85
- "cell_type": "code",
86
- "execution_count": 19,
87
- "metadata": {},
88
- "outputs": [],
89
- "source": [
90
- "from src.prompts.prompt import input_text"
91
- ]
92
- },
93
  {
94
  "cell_type": "markdown",
95
  "metadata": {},
@@ -144,8 +171,6 @@
144
  }
145
  ],
146
  "source": [
147
- "import sqlite3 as sql\n",
148
- "\n",
149
  "# Create connection to sqlite3 database\n",
150
  "connection = sql.connect('./nba-data/nba.sqlite')\n",
151
  "cursor = connection.cursor()\n",
@@ -193,115 +218,12 @@
193
  }
194
  ],
195
  "source": [
196
- "import math\n",
197
- "from src.evaluation.compare_result import compare_result_two\n",
198
- "\n",
199
- "def compare_result(sample_query, sample_result, query_output):\n",
200
- " # Clean model output to only have the query output\n",
201
- " if query_output[0:7] == \"SQLite:\":\n",
202
- " query = query_output[7:]\n",
203
- " elif query_output[0:4] == \"SQL:\":\n",
204
- " query = query_output[4:]\n",
205
- " else:\n",
206
- " query = query_output\n",
207
- " \n",
208
- " # Try to execute query, if it fails, then this is a failure of the model\n",
209
- " try:\n",
210
- " # Execute query and obtain result\n",
211
- " cursor.execute(query)\n",
212
- " rows = cursor.fetchall()\n",
213
- "\n",
214
- " # Strip all whitespace before comparing queries since there may be differences in spacing, newlines, tabs, etc.\n",
215
- " query = query.replace(\" \", \"\").replace(\"\\n\", \"\").replace(\"\\t\", \"\")\n",
216
- " sample_query = sample_query.replace(\" \", \"\").replace(\"\\n\", \"\").replace(\"\\t\", \"\")\n",
217
- " query_match = (query == sample_query)\n",
218
- "\n",
219
- " # If the queries match, the results clearly also match\n",
220
- " if query_match:\n",
221
- " return True, True, True\n",
222
- "\n",
223
- " # Check if this is a multi-line query\n",
224
- " if \"|\" in sample_result or \"(\" in sample_result:\n",
225
- " #print(rows)\n",
226
- " # Create list of results by stripping separators and splitting on them\n",
227
- " if \"(\" in sample_result:\n",
228
- " sample_result = sample_result.replace(\"(\", \"\").replace(\")\", \"\")\n",
229
- " result_list = sample_result.split(\",\") \n",
230
- " else:\n",
231
- " result_list = sample_result.split(\"|\") \n",
232
- "\n",
233
- " # Strip all results in list\n",
234
- " for i in range(len(result_list)):\n",
235
- " result_list[i] = str(result_list[i]).strip()\n",
236
- " \n",
237
- " # Loop through model result and see if it matches training example\n",
238
- " result = False\n",
239
- " for row in rows:\n",
240
- " for r in row:\n",
241
- " for res in result_list:\n",
242
- " try:\n",
243
- " if math.isclose(float(r), float(res), abs_tol=0.5):\n",
244
- " return True, query_match, True\n",
245
- " except:\n",
246
- " if r in res or res in r:\n",
247
- " return True, query_match, True\n",
248
- " \n",
249
- " # Check if the model returned a sum of examples as opposed to the whole thing\n",
250
- " if len(rows) == 1:\n",
251
- " for r in rows[0]:\n",
252
- " if r == str(len(result_list)):\n",
253
- " return True, query_match, True\n",
254
- " \n",
255
- " return True, query_match, result\n",
256
- " # Else the sample result is a single value or string\n",
257
- " else:\n",
258
- " #print(rows)\n",
259
- " result = False\n",
260
- " # Loop through model result and see if it contains the sample result\n",
261
- " for row in rows:\n",
262
- " for r in row:\n",
263
- " # Check by string\n",
264
- " if str(r) in str(sample_result):\n",
265
- " try:\n",
266
- " if math.isclose(float(r), float(sample_result), abs_tol=0.5):\n",
267
- " return True, query_match, True\n",
268
- " except:\n",
269
- " return True, query_match, True\n",
270
- " # Check by number, using try incase the cast as float fails\n",
271
- " try:\n",
272
- " if math.isclose(float(r), float(sample_result), abs_tol=0.5):\n",
273
- " return True, query_match, True\n",
274
- " except:\n",
275
- " pass\n",
276
- "\n",
277
- " # Check if the model returned a list of examples instead of a total sum (both acceptable)\n",
278
- " try:\n",
279
- " if len(rows) > 1 and len(rows) == int(sample_result):\n",
280
- " return True, query_match, True\n",
281
- " if len(rows[0]) > 1 and rows[0][1] is not None and len(rows[0]) == int(sample_result):\n",
282
- " return True, query_match, True\n",
283
- " except:\n",
284
- " pass\n",
285
- "\n",
286
- " # Compare results and return\n",
287
- " return True, query_match, result\n",
288
- " except:\n",
289
- " return False, False, False\n",
290
- "\n",
291
  "# Obtain sample\n",
292
  "sample = df.sample(n=1)\n",
293
- "sample_dic = {\n",
294
- " \"natural_query\": \"How many home games did the Miami Heat play in the 2021 season?\",\n",
295
- " \"sql_query\": \"SELECT COUNT(*) FROM game WHERE team_name_home = 'Miami Heat' AND season_id = '22021';\",\n",
296
- " \"result\": 41.0\n",
297
- "}\n",
298
  "\n",
299
- "sample = pd.DataFrame([sample_dic])\n",
300
- "\"\"\"\n",
301
  "print(sample[\"natural_query\"].values[0])\n",
302
  "print(sample[\"sql_query\"].values[0])\n",
303
  "print(sample[\"result\"].values[0])\n",
304
- "\"\"\"\n",
305
  "\n",
306
  "# Create message with sample query and run model\n",
307
  "message=[{ 'role': 'user', 'content': input_text + sample[\"natural_query\"].values[0]}]\n",
@@ -312,15 +234,10 @@
312
  "query_output = tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=True)\n",
313
  "print(query_output)\n",
314
  "\n",
315
- "result = compare_result(sample[\"sql_query\"].values[0], sample[\"result\"].values[0], query_output)\n",
316
  "print(\"Statement valid? \" + str(result[0]))\n",
317
  "print(\"SQLite matched? \" + str(result[1]))\n",
318
- "print(\"Result matched? \" + str(result[2]))\n",
319
- "\n",
320
- "result_two = compare_result_two(cursor, sample[\"sql_query\"].values[0], sample[\"result\"].values[0], query_output)\n",
321
- "print(\"Statement valid? \" + str(result_two[0]))\n",
322
- "print(\"SQLite matched? \" + str(result_two[1]))\n",
323
- "print(\"Result matched? \" + str(result_two[2]))"
324
  ]
325
  },
326
  {
 
7
  "# Run pre-trained DeepSeek Coder 1.3B Model on Chat-GPT 4o generated dataset"
8
  ]
9
  },
10
+ {
11
+ "cell_type": "code",
12
+ "execution_count": 22,
13
+ "metadata": {},
14
+ "outputs": [],
15
+ "source": [
16
+ "import pandas as pd \n",
17
+ "import warnings\n",
18
+ "warnings.filterwarnings(\"ignore\")\n",
19
+ "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
20
+ "import torch\n",
21
+ "import sys\n",
22
+ "import sqlite3 as sql\n",
23
+ "from huggingface_hub import snapshot_download"
24
+ ]
25
+ },
26
+ {
27
+ "cell_type": "code",
28
+ "execution_count": 23,
29
+ "metadata": {},
30
+ "outputs": [],
31
+ "source": [
32
+ "is_google_colab=False"
33
+ ]
34
+ },
35
+ {
36
+ "cell_type": "code",
37
+ "execution_count": 24,
38
+ "metadata": {},
39
+ "outputs": [],
40
+ "source": [
41
+ "if is_google_colab:\n",
42
+ " hugging_face_path = snapshot_download(\n",
43
+ " repo_id=\"USC-Applied-NLP-Group/SQL-Generation\",\n",
44
+ " repo_type=\"model\", \n",
45
+ " allow_patterns=[\"src/*\"], \n",
46
+ " )\n",
47
+ " sys.path.append(hugging_face_path)"
48
+ ]
49
+ },
50
+ {
51
+ "cell_type": "code",
52
+ "execution_count": 25,
53
+ "metadata": {},
54
+ "outputs": [],
55
+ "source": [
56
+ "from src.prompts.prompt import input_text\n",
57
+ "from src.evaluation.compare_result import compare_result"
58
+ ]
59
+ },
60
  {
61
  "cell_type": "markdown",
62
  "metadata": {},
 
83
  }
84
  ],
85
  "source": [
 
 
 
 
86
  "# Load dataset and check length\n",
87
  "df = pd.read_csv(\"./train-data/sql_train.tsv\", sep='\\t')\n",
88
  "print(\"Total dataset examples: \" + str(len(df)))\n",
 
108
  "metadata": {},
109
  "outputs": [],
110
  "source": [
 
 
 
111
  "# Set device to cuda if available, otherwise CPU\n",
112
  "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
113
  "\n",
 
117
  "model.generation_config.pad_token_id = tokenizer.pad_token_id"
118
  ]
119
  },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
  {
121
  "cell_type": "markdown",
122
  "metadata": {},
 
171
  }
172
  ],
173
  "source": [
 
 
174
  "# Create connection to sqlite3 database\n",
175
  "connection = sql.connect('./nba-data/nba.sqlite')\n",
176
  "cursor = connection.cursor()\n",
 
218
  }
219
  ],
220
  "source": [
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
221
  "# Obtain sample\n",
222
  "sample = df.sample(n=1)\n",
 
 
 
 
 
223
  "\n",
 
 
224
  "print(sample[\"natural_query\"].values[0])\n",
225
  "print(sample[\"sql_query\"].values[0])\n",
226
  "print(sample[\"result\"].values[0])\n",
 
227
  "\n",
228
  "# Create message with sample query and run model\n",
229
  "message=[{ 'role': 'user', 'content': input_text + sample[\"natural_query\"].values[0]}]\n",
 
234
  "query_output = tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=True)\n",
235
  "print(query_output)\n",
236
  "\n",
237
+ "result = compare_result(cursor, sample[\"sql_query\"].values[0], sample[\"result\"].values[0], query_output)\n",
238
  "print(\"Statement valid? \" + str(result[0]))\n",
239
  "print(\"SQLite matched? \" + str(result[1]))\n",
240
+ "print(\"Result matched? \" + str(result[2]))"
 
 
 
 
 
241
  ]
242
  },
243
  {