diff --git "a/train_dynamic_rag.ipynb" "b/train_dynamic_rag.ipynb" new file mode 100644--- /dev/null +++ "b/train_dynamic_rag.ipynb" @@ -0,0 +1,1494 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "a87fe5f3", + "metadata": { + "id": "a87fe5f3" + }, + "outputs": [], + "source": [ + "import pandas as pd\n", + "import torch\n", + "from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer, BitsAndBytesConfig, EarlyStoppingCallback, PreTrainedTokenizer\n", + "from torch.utils.data import DataLoader\n", + "import sys\n", + "from peft import LoraConfig, get_peft_model, TaskType\n", + "from huggingface_hub import snapshot_download\n", + "import os\n", + "import re\n", + "import contextlib #helps make pip silent\n", + "import sys\n", + "import os\n", + "import numpy as np\n", + "\n", + "with contextlib.redirect_stdout(sys.__stdout__), contextlib.redirect_stderr(sys.__stderr__):\n", + " %pip install datasets\n", + " %pip install sql_metadata\n", + "\"\"\"\"\n", + "with contextlib.redirect_stdout(sys.__stdout__), contextlib.redirect_stderr(sys.__stderr__):\n", + " %pip install datasets\n", + " %pip install sql_metadata\n", + "\"\"\"\n", + "from datasets import Dataset\n", + "from sql_metadata import Parser" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "4ec432b2", + "metadata": { + "id": "4ec432b2" + }, + "outputs": [], + "source": [ + "is_google_colab = True\n", + "use_bnb = False" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "47577a7f", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 170, + "referenced_widgets": [ + "9200f1303f124bddaa6114cdf0f5f878", + "17ddbb74e1764f37b8d34c311fae200c", + "ef732739334b4ac593fd665e01cd83c1", + "949ee3d1a9cd4060864dec5d4283ef2c", + "b98629e053674527aacca899ab7f11a9", + "84cc47dc70864bf3aa7599c06eb13c51", + "5d711bb927024d8d9f9b8bb685d6f388", + "3b80c66e0f384c45ab4187301599fab2", + "db6a23e658a34722a8f22505c6ace7b4", + "7751defbc4534d518d9e923b9019aa8b", + "fe6352bce22a40e7a936e7f90313bd02" + ] + }, + "id": "47577a7f", + "outputId": "999c4e88-3f89-49b1-9e21-abac91703bf3" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount(\"/content/drive\", force_remount=True).\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "/usr/local/lib/python3.11/dist-packages/huggingface_hub/utils/_auth.py:94: UserWarning: \n", + "The secret `HF_TOKEN` does not exist in your Colab secrets.\n", + "To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.\n", + "You will be able to reuse this secret in all of your notebooks.\n", + "Please note that authentication is recommended but still optional to access public models or datasets.\n", + " warnings.warn(\n" + ] + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "Fetching 37 files: 0%| | 0/37 [00:00:2: FutureWarning: DataFrame.applymap has been deprecated. 
Use DataFrame.map instead.\n", + " df = df.applymap(lambda x: re.sub(r'\\s+', ' ', x) if isinstance(x, str) else x)\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Total dataset examples: 1044\n" + ] + } + ], + "source": [ + "\n", + "df = pd.read_csv(read_path(\"train-data/sql_train.tsv\"), sep='\\t')\n", + "df = df.applymap(lambda x: re.sub(r'\\s+', ' ', x) if isinstance(x, str) else x)\n", + "\n", + "# Display dataset info\n", + "print(f\"Total dataset examples: {len(df)}\")\n", + "\n", + "# Load tokenizer\n", + "model_name = read_path(\"deepseek-coder-1.3b-instruct\")\n", + "tokenizer = AutoTokenizer.from_pretrained(model_name)\n", + "\n", + "# Enable 8-bit quantization for lower memory usage\n", + "bnb_config = None\n", + "if use_bnb:\n", + " bnb_config = BitsAndBytesConfig(\n", + " load_in_8bit=True,\n", + " bnb_8bit_compute_dtype=torch.float16\n", + " )\n", + "\n", + "# Load model with quantization\n", + "#device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + "device_name = 'cuda' if torch.cuda.is_available() else 'cpu'\n", + "device = torch.device(device_name)\n", + "model = AutoModelForCausalLM.from_pretrained(\n", + " model_name,\n", + " quantization_config=bnb_config,\n", + " device_map=device\n", + ")\n", + "\n", + "tokenizer.truncation_side = \"left\"\n" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "7f8b1acf", + "metadata": { + "id": "7f8b1acf" + }, + "outputs": [], + "source": [ + "natural_query_list = df[\"natural_query\"].tolist()\n", + "sql_query_list = df[\"sql_query\"].tolist()\n", + "tables = [Parser(sql_query).tables for sql_query in sql_query_list]\n", + "\n", + "dataset_dict = {\n", + " \"natural_query\": natural_query_list,\n", + " \"tables\": tables,\n", + "}\n", + "\n", + "# Create HuggingFace Dataset\n", + "dataset = Dataset.from_dict(dataset_dict)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "f385a9df", + "metadata": { + "id": "f385a9df" + }, + "outputs": [], + "source": [ + "\n", + "def format_deepseek_chat(example, tokenizer):\n", + " # Manually build the prompt as one flat string\n", + " prompt = f\"{input_prompt}{example['natural_query']}\\n\"\n", + " completion = f\"Tables:\\n{example['tables']}\"\n", + "\n", + " full_text = prompt + completion\n", + " tokenized = tokenizer(\n", + " full_text,\n", + " truncation=True,\n", + " padding=\"max_length\",\n", + " max_length=3156, # or whatever your model can handle\n", + " )\n", + "\n", + " # Mask out prompt tokens in the labels\n", + " prompt_len = len(tokenizer(prompt, truncation=True)[\"input_ids\"])\n", + " labels = tokenized[\"input_ids\"][:]\n", + " labels[:prompt_len] = [-100] * prompt_len\n", + " tokenized[\"labels\"] = labels\n", + "\n", + " return tokenized\n" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "43562f78", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 121, + "referenced_widgets": [ + "68ff2fc00bd041e7b79a811e3de1e596", + "4c41e81bcd254df7b1265206a5a6b40b", + "1a8c093fccbb437db6e0390a920f5cc5", + "e11d04a9d22a4229922e3eb4e3eb6466", + "5d89a5574a3d4a8993e6dca78d406d2d", + "dd24270dc07942a6972fbfaf58129989", + "643903cd7a5b4a52a4687ec38eb8c4dc", + "13ae11c314664c44ae18d35cf57a1334", + "e68cfd05ba994a34b93107d2eab82ad3", + "ea283e7e8b234519b881c562b7eb01d3", + "1ec5329ea0434df4b74d0f311e016c3e" + ] + }, + "id": "43562f78", + "outputId": "58e8ce3f-b7cd-4cf6-dfa4-180b4a699cf9" + }, + "outputs": [ + { + "output_type": "display_data", + 
"data": { + "text/plain": [ + "Map: 0%| | 0/1044 [00:00:21: FutureWarning: `tokenizer` is deprecated and will be removed in version 5.0.0 for `Trainer.__init__`. Use `processing_class` instead.\n", + " trainer = Trainer(\n", + "No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.\n" + ] + } + ], + "source": [ + "training_args = TrainingArguments(\n", + " output_dir=MODEL_DIR,\n", + " eval_strategy=\"epoch\", # Evaluate at the end of each epoch\n", + " save_strategy=\"epoch\", # Save model every epoch\n", + " per_device_train_batch_size=1, # LoRA allows higher batch size\n", + " per_device_eval_batch_size=1,\n", + " gradient_accumulation_steps=16,\n", + " num_train_epochs=10, # Increase if needed\n", + " learning_rate=5e-5, # Higher LR since we're only training LoRA layers\n", + " weight_decay=0.001,\n", + " logging_steps=50, # Print loss every 50 steps\n", + " save_total_limit=2, # Keep last 4 checkpoints\n", + " bf16=True if torch.cuda.is_available() else False,\n", + " push_to_hub=False,\n", + " load_best_model_at_end=True,\n", + " metric_for_best_model=\"eval_loss\",\n", + " greater_is_better=False\n", + ")\n", + "\n", + "# Trainer setup\n", + "trainer = Trainer(\n", + " model=model,\n", + " args=training_args,\n", + " train_dataset=train_dataset,\n", + " eval_dataset=val_dataset,\n", + " tokenizer=tokenizer,\n", + " callbacks=[EarlyStoppingCallback(early_stopping_patience=2)]\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b0ff5278", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 214 + }, + "id": "b0ff5278", + "outputId": "07e6446f-c680-4532-caad-d62a7d3edd6d" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m The `run_name` is currently set to the same value as `TrainingArguments.output_dir`. If this was not intended, please specify a different run name by setting the `TrainingArguments.run_name` parameter.\n", + "\u001b[34m\u001b[1mwandb\u001b[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.\n", + "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mlicesma\u001b[0m (\u001b[33mlicesma-usc\u001b[0m) to \u001b[32mhttps://api.wandb.ai\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n" + ] + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "Tracking run with wandb version 0.19.9" + ] + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "Run data is saved locally in /content/wandb/run-20250420_174906-5ypbflqe" + ] + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "Syncing run /content/drive/MyDrive/sql_gen/dyn_rag_test to Weights & Biases (docs)
" + ] + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + " View project at https://wandb.ai/licesma-usc/huggingface" + ] + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + " View run at https://wandb.ai/licesma-usc/huggingface/runs/5ypbflqe" + ] + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "\n", + "
\n", + " \n", + " \n", + " [ 4/580 00:11 < 54:56, 0.17 it/s, Epoch 0.05/10]\n", + "
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
EpochTraining LossValidation Loss

" + ] + }, + "metadata": {} + } + ], + "source": [ + "# Run training\n", + "trainer.train()\n", + "\n", + "# Merge LoRA adapters with the base model before saving\n", + "model = model.merge_and_unload()\n", + "model.save_pretrained(MODEL_DIR)\n", + "tokenizer.save_pretrained(MODEL_DIR)" + ] + }, + { + "cell_type": "code", + "source": [ + "\n", + "# Prepare query with the same prompt\n", + "input_text = \"How many points do the Los Angeles Lakers average at home?\"\n", + "message = [{'role': 'user', 'content': input_prompt + input_text}]\n", + "inputs = tokenizer.apply_chat_template(message, add_generation_prompt=True, return_tensors=\"pt\").to(model.device)\n", + "\n", + "# Generate Tables\n", + "outputs = model.generate(\n", + " inputs,\n", + " max_new_tokens=256,\n", + ")\n", + "model_output = tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=True)\n", + "\n", + "print(\"Generated Tables:\", model_output)" + ], + "metadata": { + "id": "J7qO7FE73i40" + }, + "id": "J7qO7FE73i40", + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "import sqlite3 as sql\n", + "\n", + "prompt_length = len(input_prompt)\n", + "\n", + "print(prompt_length)\n", + "\n", + "# Create connection to sqlite3 database\n", + "connection = sql.connect(read_path('nba-data/nba.sqlite'))\n", + "cursor = connection.cursor()\n", + "\n", + "for v in val_dataset:\n", + " full_example = tokenizer.decode(v[\"input_ids\"], skip_special_tokens=True)\n", + " user_prompt = full_example[:prompt_length]\n", + " question, tables = full_example[prompt_length:].split(\"Tables:\\n\")\n", + " print(question)\n", + " print(tables)\n", + " break\n", + "" + ], + "metadata": { + "id": "kwHMVyQa3n89" + }, + "id": "kwHMVyQa3n89", + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "def extract_tables_from_string(s):\n", + " keywords = {\"game\", \"team\", \"other_stats\"}\n", + " found = {k for k in keywords if k in s}\n", + " return found" + ], + "metadata": { + "id": "LhiHqAaB9uE4" + }, + "id": "LhiHqAaB9uE4", + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [], + "metadata": { + "id": "Kdd8nxWD9txh" + }, + "id": "Kdd8nxWD9txh" + }, + { + "cell_type": "code", + "source": [ + "def compare_table_lists(actual_tables, generated_tables):\n", + " actual_set = extract_tables_from_string(actual_tables)\n", + " generated_set = extract_tables_from_string(generated_tables)\n", + "\n", + " # Check if they match\n", + " return generated_set == actual_set" + ], + "metadata": { + "id": "KjAXaUgp4TfY" + }, + "id": "KjAXaUgp4TfY", + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "\n", + "num_sql_matched = 0\n", + "\n", + "first_actual = []\n", + "first_model = []\n", + "print(\"Evaluating...\")\n", + "for v in val_dataset:\n", + " full_example = tokenizer.decode(v[\"input_ids\"], skip_special_tokens=True)\n", + " user_prompt = full_example[:prompt_length]\n", + " question, training_tables = full_example[prompt_length:].split(\"Tables:\\n\")\n", + " #print(question)\n", + " #print(sql_query)\n", + "\n", + " # Obtain model output\n", + " message = [{'role': 'user', 'content': input_prompt + question}]\n", + " inputs = tokenizer.apply_chat_template(message, add_generation_prompt=True, return_tensors=\"pt\").to(model.device)\n", + "\n", + " # Generate SQL query\n", + " outputs = model.generate(\n", + " inputs,\n", + " max_new_tokens=256,\n", + " pad_token_id=tokenizer.eos_token_id,\n", + 
" )\n", + " model_output = tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=True)\n", + " after_last_colon = model_output.rsplit(\":\", 1)[-1]\n", + " tables_string = after_last_colon.replace('\\n', '').replace('\\r', '')\n", + " #print(\"Training tables:\", training_tables)\n", + " #print(\"Model tables:\", tables_string.split(\" \"))\n", + " first_actual = training_tables\n", + " first_model = tables_string\n", + " result = compare_table_lists(training_tables, tables_string)\n", + " if result:\n", + " num_sql_matched += 1\n", + "\n", + "print(\"Accuracy :\", num_sql_matched/len(val_dataset))\n", + "\n" + ], + "metadata": { + "id": "8h7bpMML6G6v" + }, + "id": "8h7bpMML6G6v", + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "\n", + "num_sql_matched = 0\n", + "\n", + "first_actual = []\n", + "first_model = []\n", + "print(\"Evaluating...\")\n", + "for v in val_dataset:\n", + " full_example = tokenizer.decode(v[\"input_ids\"], skip_special_tokens=True)\n", + " user_prompt = full_example[:prompt_length]\n", + " question, training_tables = full_example[prompt_length:].split(\"Tables:\\n\")\n", + " #print(question)\n", + " #print(sql_query)\n", + "\n", + " # Obtain model output\n", + " message = [{'role': 'user', 'content': input_prompt + question}]\n", + " inputs = tokenizer.apply_chat_template(message, add_generation_prompt=True, return_tensors=\"pt\").to(model.device)\n", + "\n", + " # Generate SQL query\n", + " outputs = model.generate(\n", + " inputs,\n", + " max_new_tokens=256,\n", + " pad_token_id=tokenizer.eos_token_id,\n", + " )\n", + " model_output = tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=True)\n", + " after_last_colon = model_output.rsplit(\":\", 1)[-1]\n", + " tables_string = after_last_colon.replace('\\n', '').replace('\\r', '')\n", + " #print(\"Training tables:\", training_tables)\n", + " #print(\"Model tables:\", tables_string.split(\" \"))\n", + " first_actual = training_tables\n", + " first_model = tables_string\n", + " result = compare_table_lists(training_tables, tables_string)\n", + " if result:\n", + " num_sql_matched += 1\n", + "\n", + "print(\"Accuracy :\", num_sql_matched/len(val_dataset))\n", + "\n" + ], + "metadata": { + "id": "CoJeZ4FoUMp_" + }, + "execution_count": null, + "outputs": [], + "id": "CoJeZ4FoUMp_" + }, + { + "cell_type": "code", + "source": [ + "model = AutoModelForCausalLM.from_pretrained(MODEL_DIR, torch_dtype=torch.bfloat16, device_map=device)\n", + "tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)\n" + ], + "metadata": { + "id": "lNG1joS3T8DN" + }, + "id": "lNG1joS3T8DN", + "execution_count": null, + "outputs": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.11" + }, + "colab": { + "provenance": [], + "gpuType": "A100" + }, + "accelerator": "GPU", + "widgets": { + "application/vnd.jupyter.widget-state+json": { + "9200f1303f124bddaa6114cdf0f5f878": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HBoxModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": 
"@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_17ddbb74e1764f37b8d34c311fae200c", + "IPY_MODEL_ef732739334b4ac593fd665e01cd83c1", + "IPY_MODEL_949ee3d1a9cd4060864dec5d4283ef2c" + ], + "layout": "IPY_MODEL_b98629e053674527aacca899ab7f11a9" + } + }, + "17ddbb74e1764f37b8d34c311fae200c": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_84cc47dc70864bf3aa7599c06eb13c51", + "placeholder": "​", + "style": "IPY_MODEL_5d711bb927024d8d9f9b8bb685d6f388", + "value": "Fetching 37 files: 100%" + } + }, + "ef732739334b4ac593fd665e01cd83c1": { + "model_module": "@jupyter-widgets/controls", + "model_name": "FloatProgressModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_3b80c66e0f384c45ab4187301599fab2", + "max": 37, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_db6a23e658a34722a8f22505c6ace7b4", + "value": 37 + } + }, + "949ee3d1a9cd4060864dec5d4283ef2c": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_7751defbc4534d518d9e923b9019aa8b", + "placeholder": "​", + "style": "IPY_MODEL_fe6352bce22a40e7a936e7f90313bd02", + "value": " 37/37 [00:00<00:00, 3657.54it/s]" + } + }, + "b98629e053674527aacca899ab7f11a9": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": 
null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "84cc47dc70864bf3aa7599c06eb13c51": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "5d711bb927024d8d9f9b8bb685d6f388": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "3b80c66e0f384c45ab4187301599fab2": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "db6a23e658a34722a8f22505c6ace7b4": { + "model_module": "@jupyter-widgets/controls", + "model_name": "ProgressStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + 
"7751defbc4534d518d9e923b9019aa8b": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "fe6352bce22a40e7a936e7f90313bd02": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "68ff2fc00bd041e7b79a811e3de1e596": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HBoxModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HBoxModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HBoxView", + "box_style": "", + "children": [ + "IPY_MODEL_4c41e81bcd254df7b1265206a5a6b40b", + "IPY_MODEL_1a8c093fccbb437db6e0390a920f5cc5", + "IPY_MODEL_e11d04a9d22a4229922e3eb4e3eb6466" + ], + "layout": "IPY_MODEL_5d89a5574a3d4a8993e6dca78d406d2d" + } + }, + "4c41e81bcd254df7b1265206a5a6b40b": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_dd24270dc07942a6972fbfaf58129989", + "placeholder": "​", + "style": "IPY_MODEL_643903cd7a5b4a52a4687ec38eb8c4dc", + "value": "Map: 100%" + } + }, + "1a8c093fccbb437db6e0390a920f5cc5": { + "model_module": "@jupyter-widgets/controls", + "model_name": "FloatProgressModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "FloatProgressModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "ProgressView", + "bar_style": "success", + "description": "", + 
"description_tooltip": null, + "layout": "IPY_MODEL_13ae11c314664c44ae18d35cf57a1334", + "max": 1044, + "min": 0, + "orientation": "horizontal", + "style": "IPY_MODEL_e68cfd05ba994a34b93107d2eab82ad3", + "value": 1044 + } + }, + "e11d04a9d22a4229922e3eb4e3eb6466": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "model_module_version": "1.5.0", + "state": { + "_dom_classes": [], + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "HTMLModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/controls", + "_view_module_version": "1.5.0", + "_view_name": "HTMLView", + "description": "", + "description_tooltip": null, + "layout": "IPY_MODEL_ea283e7e8b234519b881c562b7eb01d3", + "placeholder": "​", + "style": "IPY_MODEL_1ec5329ea0434df4b74d0f311e016c3e", + "value": " 1044/1044 [00:10<00:00, 43.90 examples/s]" + } + }, + "5d89a5574a3d4a8993e6dca78d406d2d": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "dd24270dc07942a6972fbfaf58129989": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "643903cd7a5b4a52a4687ec38eb8c4dc": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + 
"model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + }, + "13ae11c314664c44ae18d35cf57a1334": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "e68cfd05ba994a34b93107d2eab82ad3": { + "model_module": "@jupyter-widgets/controls", + "model_name": "ProgressStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_model_name": "ProgressStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "bar_color": null, + "description_width": "" + } + }, + "ea283e7e8b234519b881c562b7eb01d3": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "model_module_version": "1.2.0", + "state": { + "_model_module": "@jupyter-widgets/base", + "_model_module_version": "1.2.0", + "_model_name": "LayoutModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "LayoutView", + "align_content": null, + "align_items": null, + "align_self": null, + "border": null, + "bottom": null, + "display": null, + "flex": null, + "flex_flow": null, + "grid_area": null, + "grid_auto_columns": null, + "grid_auto_flow": null, + "grid_auto_rows": null, + "grid_column": null, + "grid_gap": null, + "grid_row": null, + "grid_template_areas": null, + "grid_template_columns": null, + "grid_template_rows": null, + "height": null, + "justify_content": null, + "justify_items": null, + "left": null, + "margin": null, + "max_height": null, + "max_width": null, + "min_height": null, + "min_width": null, + "object_fit": null, + "object_position": null, + "order": null, + "overflow": null, + "overflow_x": null, + "overflow_y": null, + "padding": null, + "right": null, + "top": null, + "visibility": null, + "width": null + } + }, + "1ec5329ea0434df4b74d0f311e016c3e": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "model_module_version": "1.5.0", + "state": { + "_model_module": "@jupyter-widgets/controls", + 
"_model_module_version": "1.5.0", + "_model_name": "DescriptionStyleModel", + "_view_count": null, + "_view_module": "@jupyter-widgets/base", + "_view_module_version": "1.2.0", + "_view_name": "StyleView", + "description_width": "" + } + } + } + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} \ No newline at end of file