licesma committed on
Commit
59f20cf
·
1 Parent(s): e029c99

Add a rag helper notebook

Browse files
rag_helper.ipynb ADDED
@@ -0,0 +1,305 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# Run pre-trained DeepSeek Coder 1.3B Model on Chat-GPT 4o generated dataset"
8
+ ]
9
+ },
10
+ {
11
+ "cell_type": "code",
12
+ "execution_count": 2,
13
+ "metadata": {},
14
+ "outputs": [],
15
+ "source": [
16
+ "import pandas as pd \n",
17
+ "import warnings\n",
18
+ "warnings.filterwarnings(\"ignore\")\n",
19
+ "from transformers import AutoTokenizer, AutoModelForCausalLM\n",
20
+ "import torch\n",
21
+ "import sys\n",
22
+ "import os\n",
23
+ "import sqlite3 as sql\n",
24
+ "from huggingface_hub import snapshot_download"
25
+ ]
26
+ },
27
+ {
28
+ "cell_type": "code",
29
+ "execution_count": 3,
30
+ "metadata": {},
31
+ "outputs": [],
32
+ "source": [
33
+ "is_google_colab=False"
34
+ ]
35
+ },
36
+ {
37
+ "cell_type": "code",
38
+ "execution_count": 4,
39
+ "metadata": {},
40
+ "outputs": [],
41
+ "source": [
42
+ "current_path = \"./\"\n",
43
+ "\n",
44
+ "def get_path(rel_path):\n",
45
+ " return os.path.join(current_path, rel_path)\n",
46
+ "\n",
47
+ "if is_google_colab:\n",
48
+ " hugging_face_path = snapshot_download(\n",
49
+ " repo_id=\"USC-Applied-NLP-Group/SQL-Generation\",\n",
50
+ " repo_type=\"model\", \n",
51
+ " allow_patterns=[\"src/*\", \"train-data/*\", \"deepseek-coder-1.3b-instruct/*\", \"nba-data/*\"], \n",
52
+ " )\n",
53
+ " sys.path.append(hugging_face_path)\n",
54
+ " current_path = hugging_face_path"
55
+ ]
56
+ },
57
+ {
58
+ "cell_type": "code",
59
+ "execution_count": 5,
60
+ "metadata": {},
61
+ "outputs": [],
62
+ "source": [
63
+ "from src.prompts.pre_rag_prompt import input_text"
64
+ ]
65
+ },
66
+ {
67
+ "cell_type": "markdown",
68
+ "metadata": {},
69
+ "source": [
70
+ "## First load dataset into pandas dataframe"
71
+ ]
72
+ },
73
+ {
74
+ "cell_type": "code",
75
+ "execution_count": 6,
76
+ "metadata": {},
77
+ "outputs": [
78
+ {
79
+ "name": "stdout",
80
+ "output_type": "stream",
81
+ "text": [
82
+ "Total dataset examples: 1044\n",
83
+ "\n",
84
+ "\n",
85
+ "What is the maximum number of team rebounds recorded by the San Antonio Spurs in away games where they committed more than 20 fouls?\n",
86
+ "SELECT MAX(o.team_rebounds_away) FROM game g JOIN other_stats o ON g.game_id = o.game_id WHERE g.team_abbreviation_away = 'SAS' AND g.pf_away > 20 AND g.season_id = '22003';\n",
87
+ "13\n"
88
+ ]
89
+ }
90
+ ],
91
+ "source": [
92
+ "# Load dataset and check length\n",
93
+ "df = pd.read_csv(get_path(\"train-data/sql_train.tsv\"), sep=\"\\t\")\n",
94
+ "print(\"Total dataset examples: \" + str(len(df)))\n",
95
+ "print(\"\\n\")\n",
96
+ "\n",
97
+ "# Test sampling\n",
98
+ "sample = df.sample(n=1)\n",
99
+ "print(sample[\"natural_query\"].values[0])\n",
100
+ "print(sample[\"sql_query\"].values[0])\n",
101
+ "print(sample[\"result\"].values[0])"
102
+ ]
103
+ },
104
+ {
105
+ "cell_type": "markdown",
106
+ "metadata": {},
107
+ "source": [
108
+ "## Load pre-trained DeepSeek model using transformers and pytorch packages"
109
+ ]
110
+ },
111
+ {
112
+ "cell_type": "code",
113
+ "execution_count": 7,
114
+ "metadata": {},
115
+ "outputs": [],
116
+ "source": [
117
+ "# Set device to cuda if available, otherwise CPU\n",
118
+ "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
119
+ "\n",
120
+ "# Load model and tokenizer\n",
121
+ "if is_google_colab:\n",
122
+ " tokenizer = AutoTokenizer.from_pretrained(get_path(\"deepseek-coder-1.3b-instruct\"))\n",
123
+ " model = AutoModelForCausalLM.from_pretrained(get_path(\"deepseek-coder-1.3b-instruct\"), torch_dtype=torch.bfloat16, device_map=device) \n",
124
+ "else:\n",
125
+ " tokenizer = AutoTokenizer.from_pretrained(\"./deepseek-coder-1.3b-instruct\")\n",
126
+ " model = AutoModelForCausalLM.from_pretrained(\"./deepseek-coder-1.3b-instruct\", torch_dtype=torch.bfloat16, device_map=device) \n",
127
+ "model.generation_config.pad_token_id = tokenizer.pad_token_id"
128
+ ]
129
+ },
130
+ {
131
+ "cell_type": "markdown",
132
+ "metadata": {},
133
+ "source": [
134
+ "## Test model performance on a single example"
135
+ ]
136
+ },
137
+ {
138
+ "cell_type": "code",
139
+ "execution_count": 8,
140
+ "metadata": {},
141
+ "outputs": [
142
+ {
143
+ "name": "stdout",
144
+ "output_type": "stream",
145
+ "text": [
146
+ "Response:\n",
147
+ "game, other_stats\n",
148
+ "\n"
149
+ ]
150
+ }
151
+ ],
152
+ "source": [
153
+ "# Create message with sample query and run model\n",
154
+ "message=[{ 'role': 'user', 'content': input_text + sample[\"natural_query\"].values[0]}]\n",
155
+ "inputs = tokenizer.apply_chat_template(message, add_generation_prompt=True, return_tensors=\"pt\").to(model.device)\n",
156
+ "outputs = model.generate(inputs, max_new_tokens=512, do_sample=False, top_k=50, top_p=0.95, num_return_sequences=1, eos_token_id=tokenizer.eos_token_id)\n",
157
+ "\n",
158
+ "# Print output\n",
159
+ "query_output = tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=True)\n",
160
+ "print(query_output)"
161
+ ]
162
+ },
163
+ {
164
+ "cell_type": "markdown",
165
+ "metadata": {},
166
+ "source": [
167
+ "# Test sample output on sqlite3 database"
168
+ ]
169
+ },
170
+ {
171
+ "cell_type": "code",
172
+ "execution_count": 9,
173
+ "metadata": {},
174
+ "outputs": [],
175
+ "source": [
176
+ "# Create connection to sqlite3 database\n",
177
+ "connection = sql.connect(get_path('nba-data/nba.sqlite'))\n",
178
+ "cursor = connection.cursor()\n",
179
+ "\n",
180
+ "# Execute query from model output and print result\n",
181
+ "if query_output[0:7] == \"SQLite:\":\n",
182
+ " print(\"cleaned\")\n",
183
+ " query = query_output[7:]\n",
184
+ "elif query_output[0:4] == \"SQL:\":\n",
185
+ " query = query_output[4:]\n",
186
+ "else:\n",
187
+ " query = query_output\n",
188
+ "\n",
189
+ "try:\n",
190
+ " cursor.execute(query)\n",
191
+ " rows = cursor.fetchall()\n",
192
+ " for row in rows:\n",
193
+ " print(row)\n",
194
+ "except sql.Error as e:\n",
195
+ " print(\"Query failed:\", e)"
196
+ ]
197
+ },
198
+ {
199
+ "cell_type": "markdown",
200
+ "metadata": {},
201
+ "source": [
202
+ "## Create function to compare output to ground truth result from examples"
203
+ ]
204
+ },
205
+ {
206
+ "cell_type": "code",
207
+ "execution_count": 12,
208
+ "metadata": {},
209
+ "outputs": [
210
+ {
211
+ "name": "stdout",
212
+ "output_type": "stream",
213
+ "text": [
214
+ "Which team abbreviation belongs to the team based in Phoenix?\n",
215
+ "SELECT abbreviation FROM team WHERE city = 'Phoenix';\n",
216
+ "PHX\n",
217
+ "\"team\"\n",
218
+ "\n"
219
+ ]
220
+ }
221
+ ],
222
+ "source": [
223
+ "# Obtain sample\n",
224
+ "sample = df.sample(n=1)\n",
225
+ "\n",
226
+ "print(sample[\"natural_query\"].values[0])\n",
227
+ "print(sample[\"sql_query\"].values[0])\n",
228
+ "print(sample[\"result\"].values[0])\n",
229
+ "\n",
230
+ "# Create message with sample query and run model\n",
231
+ "message=[{ 'role': 'user', 'content': input_text + sample[\"natural_query\"].values[0]}]\n",
232
+ "inputs = tokenizer.apply_chat_template(message, add_generation_prompt=True, return_tensors=\"pt\").to(model.device)\n",
233
+ "outputs = model.generate(inputs, max_new_tokens=512, do_sample=False, top_k=50, top_p=0.95, num_return_sequences=1, eos_token_id=tokenizer.eos_token_id)\n",
234
+ "\n",
235
+ "# Print output\n",
236
+ "query_output = tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=True)\n",
237
+ "print(query_output)\n"
238
+ ]
239
+ },
240
+ {
241
+ "cell_type": "markdown",
242
+ "metadata": {},
243
+ "source": [
244
+ "## Create function to evaluate pretrained model on full datasets"
245
+ ]
246
+ },
247
+ {
248
+ "cell_type": "code",
249
+ "execution_count": null,
250
+ "metadata": {},
251
+ "outputs": [],
252
+ "source": [
253
+ "def run_evaluation(nba_df):\n",
254
+ " for index, row in nba_df.iterrows():\n",
255
+ " # Create message with sample query and run model\n",
256
+ " message=[{ 'role': 'user', 'content': input_text + row[\"natural_query\"]}]\n",
257
+ " inputs = tokenizer.apply_chat_template(message, add_generation_prompt=True, return_tensors=\"pt\").to(model.device)\n",
258
+ " outputs = model.generate(inputs, max_new_tokens=512, do_sample=False, top_k=50, top_p=0.95, num_return_sequences=1, eos_token_id=tokenizer.eos_token_id)\n",
259
+ "\n",
260
+ " # Obtain output\n",
261
+ " query_output = tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=True)\n",
262
+ "\n",
263
+ " print(\"Query: \", row[\"sql_query\"])\n",
264
+ " print(\"Response: \",query_output)\n"
265
+ ]
266
+ },
267
+ {
268
+ "cell_type": "code",
269
+ "execution_count": null,
270
+ "metadata": {},
271
+ "outputs": [],
272
+ "source": [
273
+ "run_evaluation(df)"
274
+ ]
275
+ },
276
+ {
277
+ "cell_type": "code",
278
+ "execution_count": null,
279
+ "metadata": {},
280
+ "outputs": [],
281
+ "source": []
282
+ }
283
+ ],
284
+ "metadata": {
285
+ "kernelspec": {
286
+ "display_name": "CSCI544",
287
+ "language": "python",
288
+ "name": "python3"
289
+ },
290
+ "language_info": {
291
+ "codemirror_mode": {
292
+ "name": "ipython",
293
+ "version": 3
294
+ },
295
+ "file_extension": ".py",
296
+ "mimetype": "text/x-python",
297
+ "name": "python",
298
+ "nbconvert_exporter": "python",
299
+ "pygments_lexer": "ipython3",
300
+ "version": "3.11.11"
301
+ }
302
+ },
303
+ "nbformat": 4,
304
+ "nbformat_minor": 2
305
+ }
src/prompts/__pycache__/pre_rag_prompt.cpython-311.pyc ADDED
Binary file (4.13 kB). View file