licesma commited on
Commit
2ea9086
·
1 Parent(s): 59f20cf

Add changes to save new rag dadataframe

Browse files
Files changed (1) hide show
  1. rag_helper.ipynb +13 -3
rag_helper.ipynb CHANGED
@@ -251,6 +251,9 @@
251
  "outputs": [],
252
  "source": [
253
  "def run_evaluation(nba_df):\n",
 
 
 
254
  " for index, row in nba_df.iterrows():\n",
255
  " # Create message with sample query and run model\n",
256
  " message=[{ 'role': 'user', 'content': input_text + row[\"natural_query\"]}]\n",
@@ -259,9 +262,16 @@
259
  "\n",
260
  " # Obtain output\n",
261
  " query_output = tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=True)\n",
262
- "\n",
263
- " print(\"Query: \", + row[\"sql_query\"])\n",
264
- " print(\"Response: \",query_output)\n"
 
 
 
 
 
 
 
265
  ]
266
  },
267
  {
 
251
  "outputs": [],
252
  "source": [
253
  "def run_evaluation(nba_df):\n",
254
+ " team_flags = []\n",
255
+ " game_flags = []\n",
256
+ " other_stats_flags =[]\n",
257
  " for index, row in nba_df.iterrows():\n",
258
  " # Create message with sample query and run model\n",
259
  " message=[{ 'role': 'user', 'content': input_text + row[\"natural_query\"]}]\n",
 
262
  "\n",
263
  " # Obtain output\n",
264
  " query_output = tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=True)\n",
265
+ " team_flags.append(\"team\" in query_output.lower())\n",
266
+ " game_flags.append(\"game\" in query_output.lower())\n",
267
+ " other_stats_flags.append(\"other_stats\" in query_output.lower())\n",
268
+ " #print(\"Query: \", + row[\"sql_query\"])\n",
269
+ " #print(\"Response: \",query_output)\n",
270
+ " \n",
271
+ " nba_df[\"team_flag\"] = team_flags\n",
272
+ " nba_df[\"game_flag\"] = game_flags\n",
273
+ " nba_df[\"other_stats_flag\"] = other_stats_flags\n",
274
+ " nba_df.to_csv(get_path(\"expanded_dta.tsv\"), sep=\"\\t\", index=False)\n"
275
  ]
276
  },
277
  {