Add changes to save new rag dadataframe
Browse files- rag_helper.ipynb +13 -3
rag_helper.ipynb
CHANGED
@@ -251,6 +251,9 @@
|
|
251 |
"outputs": [],
|
252 |
"source": [
|
253 |
"def run_evaluation(nba_df):\n",
|
|
|
|
|
|
|
254 |
" for index, row in nba_df.iterrows():\n",
|
255 |
" # Create message with sample query and run model\n",
|
256 |
" message=[{ 'role': 'user', 'content': input_text + row[\"natural_query\"]}]\n",
|
@@ -259,9 +262,16 @@
|
|
259 |
"\n",
|
260 |
" # Obtain output\n",
|
261 |
" query_output = tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=True)\n",
|
262 |
-
"\n",
|
263 |
-
"
|
264 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
265 |
]
|
266 |
},
|
267 |
{
|
|
|
251 |
"outputs": [],
|
252 |
"source": [
|
253 |
"def run_evaluation(nba_df):\n",
|
254 |
+
" team_flags = []\n",
|
255 |
+
" game_flags = []\n",
|
256 |
+
" other_stats_flags =[]\n",
|
257 |
" for index, row in nba_df.iterrows():\n",
|
258 |
" # Create message with sample query and run model\n",
|
259 |
" message=[{ 'role': 'user', 'content': input_text + row[\"natural_query\"]}]\n",
|
|
|
262 |
"\n",
|
263 |
" # Obtain output\n",
|
264 |
" query_output = tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=True)\n",
|
265 |
+
" team_flags.append(\"team\" in query_output.lower())\n",
|
266 |
+
" game_flags.append(\"game\" in query_output.lower())\n",
|
267 |
+
" other_stats_flags.append(\"other_stats\" in query_output.lower())\n",
|
268 |
+
" #print(\"Query: \", + row[\"sql_query\"])\n",
|
269 |
+
" #print(\"Response: \",query_output)\n",
|
270 |
+
" \n",
|
271 |
+
" nba_df[\"team_flag\"] = team_flags\n",
|
272 |
+
" nba_df[\"game_flag\"] = game_flags\n",
|
273 |
+
" nba_df[\"other_stats_flag\"] = other_stats_flags\n",
|
274 |
+
" nba_df.to_csv(get_path(\"expanded_dta.tsv\"), sep=\"\\t\", index=False)\n"
|
275 |
]
|
276 |
},
|
277 |
{
|