{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "jaosjY4rGRNH" }, "source": [ "# Installing NeMo from source\n", "\n", "\n", "You can run this notebook either locally (if you have all the dependencies and a GPU) or on Google Colab.\n", "\n", "Instructions for setting up Colab are as follows:\n", "1. Open a new Python 3 notebook.\n", "2. Import this notebook from GitHub (File -> Upload Notebook -> \"GITHUB\" tab -> copy/paste GitHub URL)\n", "3. Connect to an instance with a GPU (Runtime -> Change runtime type -> select \"GPU\" for hardware accelerator)\n", "4. Run the cell below to set up dependencies.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "goQzOSflEq27" }, "outputs": [], "source": [ "import os\n", "BRANCH = 'r1.17.0'\n", "# install system dependencies, then clone NeMo at the chosen branch and install it\n", "!apt-get update && apt-get install -y libsndfile1 ffmpeg\n", "!git clone https://github.com/NVIDIA/NeMo --branch $BRANCH\n", "os.chdir('NeMo')\n", "!./reinstall.sh\n", "os.chdir('..')\n" ] }, { "cell_type": "markdown", "metadata": { "id": "GjQ_z_xQMDIb" }, "source": [ "# Overview\n", "\n", "This tutorial covers three tasks:\n", "\n", "1. Intent and Slot Classification using the Assistant dataset and a BERT model\n", "2. Intent Classification using the Schema Guided Dialogue (SGD) dataset and a GPT2 model\n", "3. Answer Extender using the MS Marco NLGen dataset and a BART model\n", "\n", "Feel free to skip to the task that interests you most after installing NeMo from source." ] }, { "cell_type": "markdown", "metadata": { "id": "AS-zwy8tEq2_" }, "source": [ "# 1. 
Intent and Slot Classification using Assistant Dataset\n", "\n", "## 1.1 Task Description\n", "\n", "**Joint Intent and Slot classification** is the task of classifying the Intent of a query and detecting all relevant Slots (Entities)\n", "for that Intent.\n", "For example, in the query `What is the weather in Santa Clara tomorrow morning?`, we would like to classify the query\n", "as a `weather` Intent, and detect `Santa Clara` as a `location` slot and `tomorrow morning` as a `date_time` slot.\n", "Intent and Slot names are usually task-specific and defined as labels in the training data.\n", "This is a fundamental step in any task-driven conversational assistant.\n", "\n", "The model in this tutorial is trained to perform both tasks jointly.\n", "\n", "Note: A similar model is available in the [Joint Intent Slot Classification Colab](https://colab.research.google.com/github/NVIDIA/NeMo/blob/stable/tutorials/nlp/Joint_Intent_and_Slot_Classification.ipynb). However, that model only supports BERT-style models, while the model in this tutorial also supports other model types such as GPT2." ] }, { "cell_type": "markdown", "metadata": { "id": "FJk_UAyeEq3B" }, "source": [ "\n", "## 1.2 Download Assistant dataset and convert to NeMo format\n", "\n", "This is a virtual assistant interaction dataset, which can be downloaded from https://github.com/xliuhw/NLU-Evaluation-Data.\n", "It contains about 10K training and 1K testing queries covering 64 Intents and 55 Slots. 
\n", "\n", "An example is:\n", "\n", "* utterance: what alarms have i set for tomorrow \n", "* intent: alarm_query\n", "* slots: date(tomorrow)\n", "\n", "\n", "Note: While only the Assistant dataset is used here, `import_datasets.py` also supports the ATIS and SNIPS datasets." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "jjOVdGX2Eq3D" }, "outputs": [], "source": [ "# download and unzip the example dataset from GitHub\n", "!wget https://github.com/xliuhw/NLU-Evaluation-Data/archive/master.zip\n", "!unzip master.zip\n", "# convert the dataset to the NeMo format\n", "!python NeMo/scripts/dataset_processing/nlp/intent_and_slot/import_datasets.py --dataset_name=assistant --source_data_dir=./NLU-Evaluation-Data-master --target_data_dir=./assistant" ] }, { "cell_type": "markdown", "metadata": { "id": "5n81deZsEq3G" }, "source": [ "## 1.3 Training and/or Testing the model\n", "\n", "\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "eoYc_8jhEq3G" }, "outputs": [], "source": [ "# model.dataset.data_dir: folder to load data from\n", "# model.dataset.dialogues_example_dir: folder that stores predictions for each sample\n", "!(python NeMo/examples/nlp/dialogue/dialogue.py \\\n", " do_training=True \\\n", " model.dataset.data_dir='./assistant' \\\n", " model.dataset.dialogues_example_dir='./assistant_bert_examples' \\\n", " model.dataset.task='assistant' \\\n", " model.language_model.pretrained_model_name='bert-base-uncased' \\\n", " exp_manager.create_wandb_logger=False)\n" ] }, { "cell_type": "markdown", "metadata": { "id": "GaPmHjayEbg8" }, "source": [ "**Results after 3 epochs**\n", "\n", "Intent report: \n", "```\n", " label precision recall f1 support \n", " alarm_query (label_id: 0) 100.00 94.44 97.14 18\n", " alarm_remove (label_id: 1) 100.00 90.91 95.24 11\n", " alarm_set (label_id: 2) 94.12 94.12 94.12 17\n", " audio_volume_down (label_id: 3) 75.00 42.86 54.55 7\n", " audio_volume_mute (label_id: 4) 100.00 92.86 96.30 
14\n", " audio_volume_up (label_id: 5) 72.22 100.00 83.87 13\n", " calendar_query (label_id: 6) 87.50 77.78 82.35 18\n", " calendar_remove (label_id: 7) 94.44 100.00 97.14 17\n", " calendar_set (label_id: 8) 94.44 94.44 94.44 18\n", " cooking_recipe (label_id: 9) 85.71 70.59 77.42 17\n", " datetime_convert (label_id: 10) 88.89 100.00 94.12 8\n", " datetime_query (label_id: 11) 89.47 100.00 94.44 17\n", " email_addcontact (label_id: 12) 80.00 100.00 88.89 8\n", " email_query (label_id: 13) 100.00 83.33 90.91 18\n", " email_querycontact (label_id: 14) 78.95 88.24 83.33 17\n", " email_sendemail (label_id: 15) 94.44 94.44 94.44 18\n", " general_affirm (label_id: 16) 100.00 100.00 100.00 17\n", " general_commandstop (label_id: 17) 100.00 100.00 100.00 18\n", " general_confirm (label_id: 18) 100.00 100.00 100.00 17\n", " general_dontcare (label_id: 19) 100.00 100.00 100.00 18\n", " general_explain (label_id: 20) 100.00 100.00 100.00 17\n", " general_joke (label_id: 21) 91.67 100.00 95.65 11\n", " general_negate (label_id: 22) 100.00 100.00 100.00 18\n", " general_praise (label_id: 23) 100.00 100.00 100.00 17\n", " general_quirky (label_id: 24) 60.00 50.00 54.55 18\n", " general_repeat (label_id: 25) 100.00 100.00 100.00 17\n", " iot_cleaning (label_id: 26) 100.00 100.00 100.00 15\n", " iot_coffee (label_id: 27) 85.71 100.00 92.31 18\n", " iot_hue_lightchange (label_id: 28) 100.00 94.12 96.97 17\n", " iot_hue_lightdim (label_id: 29) 100.00 100.00 100.00 12\n", " iot_hue_lightoff (label_id: 30) 100.00 100.00 100.00 17\n", " iot_hue_lighton (label_id: 31) 100.00 50.00 66.67 4\n", " iot_hue_lightup (label_id: 32) 84.62 91.67 88.00 12\n", " iot_wemo_off (label_id: 33) 100.00 100.00 100.00 9\n", " iot_wemo_on (label_id: 34) 100.00 85.71 92.31 7\n", " lists_createoradd (label_id: 35) 90.00 100.00 94.74 18\n", " lists_query (label_id: 36) 100.00 94.12 96.97 17\n", " lists_remove (label_id: 37) 88.89 88.89 88.89 18\n", " music_likeness (label_id: 38) 100.00 93.75 96.77 16\n", " 
music_query (label_id: 39) 100.00 100.00 100.00 17\n", " music_settings (label_id: 40) 77.78 100.00 87.50 7\n", " news_query (label_id: 41) 72.73 88.89 80.00 18\n", " play_audiobook (label_id: 42) 100.00 100.00 100.00 17\n", " play_game (label_id: 43) 93.75 83.33 88.24 18\n", " play_music (label_id: 44) 85.00 100.00 91.89 17\n", " play_podcasts (label_id: 45) 100.00 88.89 94.12 18\n", " play_radio (label_id: 46) 84.21 94.12 88.89 17\n", " qa_currency (label_id: 47) 85.00 94.44 89.47 18\n", " qa_definition (label_id: 48) 89.47 100.00 94.44 17\n", " qa_factoid (label_id: 49) 64.00 88.89 74.42 18\n", " qa_maths (label_id: 50) 84.62 84.62 84.62 13\n", " qa_stock (label_id: 51) 87.50 77.78 82.35 18\n", " recommendation_events (label_id: 52) 87.50 82.35 84.85 17\n", " recommendation_locations (label_id: 53) 83.33 83.33 83.33 18\n", " recommendation_movies (label_id: 54) 100.00 60.00 75.00 10\n", " social_post (label_id: 55) 100.00 94.12 96.97 17\n", " social_query (label_id: 56) 100.00 82.35 90.32 17\n", " takeaway_order (label_id: 57) 92.31 70.59 80.00 17\n", " takeaway_query (label_id: 58) 93.75 83.33 88.24 18\n", " transport_query (label_id: 59) 81.25 76.47 78.79 17\n", " transport_taxi (label_id: 60) 100.00 100.00 100.00 16\n", " transport_ticket (label_id: 61) 85.00 94.44 89.47 18\n", " transport_traffic (label_id: 62) 93.75 88.24 90.91 17\n", " weather_query (label_id: 63) 89.47 100.00 94.44 17\n", " -------------------\n", " micro avg 91.16 91.16 91.16 996\n", " macro avg 91.66 90.44 90.48 996\n", " weighted avg 91.72 91.16 91.04 996\n", "```\n", "Slot report: \n", "```\n", " label precision recall f1 support \n", " alarm_type (label_id: 0) 0.00 0.00 0.00 2\n", " app_name (label_id: 1) 0.00 0.00 0.00 1\n", " artist_name (label_id: 2) 17.39 80.00 28.57 5\n", " audiobook_author (label_id: 3) 0.00 0.00 0.00 0\n", " audiobook_name (label_id: 4) 64.52 74.07 68.97 27\n", " business_name (label_id: 5) 81.48 84.62 83.02 52\n", " business_type (label_id: 6) 80.00 80.00 
80.00 20\n", " change_amount (label_id: 7) 57.14 66.67 61.54 6\n", " coffee_type (label_id: 8) 100.00 33.33 50.00 3\n", " color_type (label_id: 9) 75.00 92.31 82.76 13\n", " cooking_type (label_id: 10) 0.00 0.00 0.00 1\n", " currency_name (label_id: 11) 100.00 96.43 98.18 28\n", " date (label_id: 12) 87.88 87.22 87.55 133\n", " definition_word (label_id: 13) 85.00 85.00 85.00 20\n", " device_type (label_id: 14) 84.75 76.92 80.65 65\n", " drink_type (label_id: 15) 0.00 0.00 0.00 0\n", " email_address (label_id: 16) 64.29 100.00 78.26 9\n", " email_folder (label_id: 17) 100.00 50.00 66.67 2\n", " event_name (label_id: 18) 80.00 75.00 77.42 64\n", " food_type (label_id: 19) 84.38 77.14 80.60 35\n", " game_name (label_id: 20) 93.55 78.38 85.29 37\n", " game_type (label_id: 21) 0.00 0.00 0.00 0\n", " general_frequency (label_id: 22) 0.00 0.00 0.00 9\n", " house_place (label_id: 23) 80.95 91.89 86.08 37\n", " ingredient (label_id: 24) 0.00 0.00 0.00 1\n", " joke_type (label_id: 25) 100.00 100.00 100.00 5\n", " list_name (label_id: 26) 89.29 69.44 78.12 36\n", " meal_type (label_id: 27) 0.00 0.00 0.00 3\n", " media_type (label_id: 28) 78.95 83.33 81.08 36\n", " movie_name (label_id: 29) 0.00 0.00 0.00 1\n", " movie_type (label_id: 30) 0.00 0.00 0.00 0\n", " music_album (label_id: 31) 0.00 0.00 0.00 0\n", " music_descriptor (label_id: 32) 0.00 0.00 0.00 2\n", " music_genre (label_id: 33) 81.82 90.00 85.71 10\n", " news_topic (label_id: 34) 80.00 30.77 44.44 13\n", " order_type (label_id: 35) 100.00 42.11 59.26 19\n", " person (label_id: 36) 70.79 100.00 82.89 63\n", " personal_info (label_id: 37) 76.19 94.12 84.21 17\n", " place_name (label_id: 38) 82.86 84.47 83.65 103\n", " player_setting (label_id: 39) 75.00 42.86 54.55 7\n", " playlist_name (label_id: 40) 0.00 0.00 0.00 3\n", " podcast_descriptor (label_id: 41) 92.31 54.55 68.57 22\n", " podcast_name (label_id: 42) 66.67 16.67 26.67 12\n", " radio_name (label_id: 43) 94.87 94.87 94.87 39\n", " relation (label_id: 44) 
90.91 90.91 90.91 11\n", " song_name (label_id: 45) 100.00 6.67 12.50 15\n", " time (label_id: 46) 77.57 84.69 80.98 98\n", " time_zone (label_id: 47) 44.44 100.00 61.54 4\n", " timeofday (label_id: 48) 86.96 80.00 83.33 25\n", " transport_agency (label_id: 49) 80.00 57.14 66.67 7\n", " transport_descriptor (label_id: 50) 0.00 0.00 0.00 5\n", " transport_name (label_id: 51) 0.00 0.00 0.00 0\n", " transport_type (label_id: 52) 88.89 100.00 94.12 40\n", " weather_descriptor (label_id: 53) 87.50 87.50 87.50 8\n", " O (label_id: 54) 97.07 97.52 97.30 5408\n", " -------------------\n", " micro avg 94.24 94.24 94.24 6582\n", " macro avg 64.87 59.93 59.17 6582\n", " weighted avg 94.23 94.24 93.95 6582\n", "```" ] }, { "cell_type": "markdown", "metadata": { "id": "-44x5PqyrOeQ" }, "source": [ "## 1.4 (Optional) Training and/or Testing a GPT2 model on the Assistant dataset" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "QyqQbpR4rNHT" }, "outputs": [], "source": [ "# model.dataset.data_dir: folder to load data from\n", "# model.dataset.dialogues_example_dir: folder that stores predictions for each sample\n", "# trainer.max_epochs=1: train for a single epoch to keep the run short\n", "# model.tokenizer.special_tokens=\"{pad_token:'<|endoftext|>'}\": GPT2 does not define a pad token, so its EOS token is used as the pad token\n", "# model.dataset.target_template=with_slots: performs slot filling together with intent classification\n", "!(python NeMo/examples/nlp/dialogue/dialogue.py \\\n", " do_training=True \\\n", " model.dataset.data_dir='./assistant' \\\n", " model.dataset.dialogues_example_dir='./assistant_gpt2_examples' \\\n", " model.dataset.task='assistant' \\\n", " model.language_model.pretrained_model_name='gpt2' \\\n", " trainer.max_epochs=1 \\\n", " model.tokenizer.special_tokens=\"{pad_token:'<|endoftext|>'}\" \\\n", " model.dataset.target_template=with_slots \\\n", " model.dataset.eval_mode=generation \\\n", " exp_manager.create_wandb_logger=False)" ] }, { "cell_type": "markdown", "metadata": { "id": 
"FbQ-6TVM1yQg" }, "source": [ "**After 1 epoch:**\n", "\n", "More epochs would be helpful\n", "\n", "Intent report:\n", "\n", " ```\n", " label precision recall f1 support \n", " transport query (label_id: 0) 72.73 84.21 78.05 19\n", " weather query (label_id: 1) 94.74 94.74 94.74 19\n", " play game (label_id: 2) 92.86 68.42 78.79 19\n", " qa currency (label_id: 3) 100.00 100.00 100.00 19\n", " qa maths (label_id: 4) 100.00 100.00 100.00 14\n", " iot wemo off (label_id: 5) 75.00 100.00 85.71 9\n", " datetime convert (label_id: 6) 46.67 87.50 60.87 8\n", " email addcontact (label_id: 7) 70.00 87.50 77.78 8\n", " music likeness (label_id: 8) 57.89 61.11 59.46 18\n", " music query (label_id: 9) 78.57 57.89 66.67 19\n", " general negate (label_id: 10) 95.00 100.00 97.44 19\n", " email sendemail (label_id: 11) 92.86 68.42 78.79 19\n", " general affirm (label_id: 12) 95.00 100.00 97.44 19\n", " play audiobook (label_id: 13) 57.69 78.95 66.67 19\n", " general praise (label_id: 14) 100.00 94.74 97.30 19\n", " alarm set (label_id: 15) 85.71 94.74 90.00 19\n", " general explain (label_id: 16) 100.00 89.47 94.44 19\n", " iot wemo on (label_id: 17) 83.33 71.43 76.92 7\n", " cooking recipe (label_id: 18) 90.00 94.74 92.31 19\n", " music settings (label_id: 19) 60.00 42.86 50.00 7\n", " social post (label_id: 20) 84.21 84.21 84.21 19\n", " recommendation events (label_id: 21) 72.73 84.21 78.05 19\n", " audio volume up (label_id: 22) 76.47 100.00 86.67 13\n", " lists remove (label_id: 23) 73.08 100.00 84.44 19\n", " transport ticket (label_id: 24) 94.74 94.74 94.74 19\n", " general joke (label_id: 25) 100.00 100.00 100.00 12\n", " play podcasts (label_id: 26) 94.12 84.21 88.89 19\n", " iot hue lightchange (label_id: 27) 85.71 63.16 72.73 19\n", " audio volume mute (label_id: 28) 84.62 73.33 78.57 15\n", " general dontcare (label_id: 29) 95.00 100.00 97.44 19\n", " qa definition (label_id: 30) 77.27 89.47 82.93 19\n", " email querycontact (label_id: 31) 58.33 73.68 65.12 19\n", " 
general commandstop (label_id: 32) 100.00 100.00 100.00 19\n", " calendar remove (label_id: 33) 94.44 89.47 91.89 19\n", " news query (label_id: 34) 100.00 57.89 73.33 19\n", " calendar query (label_id: 35) 63.16 63.16 63.16 19\n", " social query (label_id: 36) 88.24 83.33 85.71 18\n", " transport traffic (label_id: 37) 90.48 100.00 95.00 19\n", " transport taxi (label_id: 38) 100.00 94.44 97.14 18\n", " alarm query (label_id: 39) 100.00 94.74 97.30 19\n", " iot hue lightoff (label_id: 40) 88.89 84.21 86.49 19\n", " takeaway order (label_id: 41) 81.25 68.42 74.29 19\n", " iot coffee (label_id: 42) 100.00 94.74 97.30 19\n", " recommendation movies (label_id: 43) 75.00 90.00 81.82 10\n", " iot hue lightup (label_id: 44) 78.57 78.57 78.57 14\n", " email query (label_id: 45) 85.71 94.74 90.00 19\n", " lists createoradd (label_id: 46) 82.35 73.68 77.78 19\n", " play radio (label_id: 47) 84.21 84.21 84.21 19\n", " audio volume down (label_id: 48) 100.00 87.50 93.33 8\n", " general quirky (label_id: 49) 30.00 15.79 20.69 19\n", " play music (label_id: 50) 71.43 52.63 60.61 19\n", " qa stock (label_id: 51) 90.48 100.00 95.00 19\n", " iot cleaning (label_id: 52) 93.33 87.50 90.32 16\n", " iot hue lightdim (label_id: 53) 100.00 100.00 100.00 12\n", " recommendation locations (label_id: 54) 100.00 89.47 94.44 19\n", " general repeat (label_id: 55) 100.00 100.00 100.00 19\n", " takeaway query (label_id: 56) 77.27 89.47 82.93 19\n", " alarm remove (label_id: 57) 100.00 100.00 100.00 11\n", " datetime query (label_id: 58) 75.00 63.16 68.57 19\n", " iot hue lighton (label_id: 59) 60.00 100.00 75.00 3\n", " qa factoid (label_id: 60) 50.00 57.89 53.66 19\n", " calendar set (label_id: 61) 75.00 78.95 76.92 19\n", " general confirm (label_id: 62) 100.00 100.00 100.00 19\n", " lists query (label_id: 63) 66.67 73.68 70.00 19\n", " label_id: 64 0.00 0.00 0.00 0\n", " -------------------\n", " micro avg 83.55 83.55 83.55 1076\n", " macro avg 83.53 83.93 83.01 1076\n", " weighted avg 
84.26 83.55 83.30 1076\n", " \n", "```\n", "\n", "```\n", "────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\n", " Test metric DataLoader 0\n", "────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\n", " intent_f1 83.55018615722656\n", " intent_precision 83.55018615722656\n", " intent_recall 83.55018615722656\n", " slot_f1 73.99985919756773\n", "slot_joint_goal_accuracy 65.89219330855019\n", " slot_precision 73.85223048327137\n", " slot_recall 74.14807930607186\n", " test_intent_accuracy 83.55018587360595\n", " test_loss_epoch 0.019178826361894608\n", "────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\n", "```" ] }, { "cell_type": "markdown", "metadata": { "id": "Gd42arYoEq3J" }, "source": [ "# 2. Schema Guided Dialogue (SGD)\n", "\n", "## 2.1 Task Description\n", "---\n", "\n", "SGD is a multi-domain intent classification dataset from Google with close to 100k examples.\n", "\n", "An example is:\n", "\n", "* utterance: I will be eating there at 11:30 am so make the reservation for then.\n", "* intent: ReserveRestaurant\n", "* slots: {\"time\": \"11:30 am\"}\n", "\n", "\n" ] }, { "cell_type": "markdown", "metadata": { "id": "neH8rXwjEq3J" }, "source": [ "## 2.2 Download the dataset" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "IgD8eavfJ5pi" }, "outputs": [], "source": [ "!git clone https://github.com/google-research-datasets/dstc8-schema-guided-dialogue.git" ] }, { "cell_type": "markdown", "metadata": { "id": "7G7uPrUpEq3J" }, "source": [ "## 2.3 Training and/or Testing the model\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "gqo-rwQlEq3K" }, "outputs": [], "source": [ "# model.dataset.data_dir: folder to load data from\n", "# model.dataset.dialogues_example_dir: folder that stores predictions for 
each sample\n", "# model.tokenizer.special_tokens=\"{pad_token:'<|endoftext|>'}\": GPT2 does not define a pad token, so its EOS token is used as the pad token\n", "\n", "!(python NeMo/examples/nlp/dialogue/dialogue.py \\\n", " do_training=True \\\n", " model.dataset.data_dir='./dstc8-schema-guided-dialogue' \\\n", " model.dataset.dialogues_example_dir='./sgd_gpt2_predictions' \\\n", " model.dataset.task='sgd' \\\n", " model.language_model.pretrained_model_name='gpt2' \\\n", " trainer.max_epochs=1 \\\n", " model.tokenizer.special_tokens=\"{pad_token:'<|endoftext|>'}\" \\\n", " exp_manager.create_wandb_logger=False)\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "kGDlV5HvI2PQ" }, "outputs": [], "source": [ "# list the contents of the predictions folder\n", "!ls sgd_gpt2_predictions" ] }, { "cell_type": "markdown", "metadata": { "id": "p8g0f5KDTu9K" }, "source": [ "**After 1 epoch:**\n", "\n", "More epochs would be needed to reach convergence.\n", "\n", "\n", "```\n", " label precision recall f1 support \n", " check balance (label_id: 0) 0.00 0.00 0.00 0\n", " find trains (label_id: 1) 80.20 91.95 85.68 348\n", " make payment (label_id: 2) 83.12 28.07 41.97 228\n", " book appointment (label_id: 3) 86.93 87.15 87.04 397\n", " get cars available (label_id: 4) 96.88 90.51 93.58 274\n", " get event dates (label_id: 5) 0.00 0.00 0.00 0\n", " buy bus ticket (label_id: 6) 78.61 91.33 84.49 173\n", " add event (label_id: 7) 0.00 0.00 0.00 0\n", " get alarms (label_id: 8) 58.33 77.78 66.67 45\n", " reserve car (label_id: 9) 83.75 72.43 77.68 185\n", " get events (label_id: 10) 0.00 0.00 0.00 0\n", " reserve roundtrip flights (label_id: 11) 0.00 0.00 0.00 0\n", " lookup music (label_id: 12) 89.83 86.89 88.33 61\n", " book house (label_id: 13) 91.13 92.50 91.81 200\n", " search oneway flight (label_id: 14) 74.77 47.70 58.25 174\n", " buy event tickets (label_id: 15) 72.19 95.31 82.15 128\n", " find apartment (label_id: 16) 0.00 0.00 0.00 0\n", " schedule visit (label_id: 17) 77.27 66.06 71.23 
386\n", " play media (label_id: 18) 92.94 86.81 89.77 91\n", " get ride (label_id: 19) 99.41 98.82 99.12 170\n", " reserve oneway flight (label_id: 20) 0.00 0.00 0.00 0\n", " find bus (label_id: 21) 96.64 87.53 91.86 361\n", " find restaurants (label_id: 22) 77.14 91.22 83.59 148\n", " get times for movie (label_id: 23) 0.00 0.00 0.00 0\n", " transfer money (label_id: 24) 0.00 0.00 0.00 0\n", " request payment (label_id: 25) 46.71 63.39 53.79 112\n", " play movie (label_id: 26) 100.00 65.11 78.87 321\n", " search house (label_id: 27) 97.91 91.83 94.77 306\n", " search roundtrip flights (label_id: 28) 67.49 82.41 74.21 199\n", " find provider (label_id: 29) 95.11 90.53 92.77 602\n", " find attractions (label_id: 30) 100.00 89.01 94.19 91\n", " reserve hotel (label_id: 31) 56.75 97.04 71.62 169\n", " lookup song (label_id: 32) 0.00 0.00 0.00 0\n", " add alarm (label_id: 33) 95.68 60.18 73.89 221\n", " find home by area (label_id: 34) 48.95 59.79 53.83 194\n", " get available time (label_id: 35) 0.00 0.00 0.00 0\n", " buy movie tickets (label_id: 36) 100.00 29.39 45.42 473\n", " reserve restaurant (label_id: 37) 95.71 84.80 89.92 342\n", " find movies (label_id: 38) 62.40 97.61 76.14 335\n", " get weather (label_id: 39) 100.00 87.69 93.44 195\n", " search hotel (label_id: 40) 99.35 52.60 68.78 289\n", " find events (label_id: 41) 99.57 82.56 90.27 281\n", " play song (label_id: 42) 0.00 0.00 0.00 0\n", " rent movie (label_id: 43) 0.00 0.00 0.00 0\n", " get train tickets (label_id: 44) 45.83 5.56 9.91 198\n", " none (label_id: 45) 55.77 98.90 71.32 728\n", " label_id: 46 0.00 0.00 0.00 0\n", " -------------------\n", " micro avg 77.23 77.23 77.23 8425\n", " macro avg 82.01 76.68 76.56 8425\n", " weighted avg 83.23 77.23 76.86 8425\n", "\n", "```" ] }, { "cell_type": "markdown", "metadata": { "id": "jUJb-9VLLBXo" }, "source": [ "# 3. 
MS Marco\n", "\n", "## 3.1 Task Description\n", "\n", "MS Marco NLGen is a dataset from Microsoft in which a question and an extracted answer are used to generate a fluent, well-formed answer.\n", "\n", "An example is:\n", "\n", "\n", "* question: What county is Nine Mile in?\n", "* extracted_answer: Onondaga\n", "* fluent_answer: Nine Mile is in Onondaga county.\n" ] }, { "cell_type": "markdown", "metadata": { "id": "VtXEKG_UQU9u" }, "source": [ "## 3.2 Download and unzip the dataset" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "b9avsZ1CEq3K" }, "outputs": [], "source": [ "!mkdir -p ms_marco\n", "os.chdir('ms_marco')\n", "# download the v2.1 train and dev splits\n", "!wget https://msmarco.blob.core.windows.net/msmarco/train_v2.1.json.gz\n", "!wget https://msmarco.blob.core.windows.net/msmarco/dev_v2.1.json.gz\n", "\n", "!gunzip train_v2.1.json.gz\n", "!gunzip dev_v2.1.json.gz\n", "\n", "# keep only samples that have a well-formed (fluent) answer\n", "!python ../NeMo/examples/nlp/dialogue/remove_ms_marco_samples_without_wellFormedAnswers.py --filename train_v2.1.json\n", "!python ../NeMo/examples/nlp/dialogue/remove_ms_marco_samples_without_wellFormedAnswers.py --filename dev_v2.1.json\n", "\n", "os.chdir('..')" ] }, { "cell_type": "markdown", "metadata": { "id": "h7UZ9R8gQTFo" }, "source": [ "## 3.3 Training and/or Testing the model\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "fwGQCwbvRf2m" }, "outputs": [], "source": [ "# model.dataset.data_dir: folder to load data from\n", "# model.dataset.dialogues_example_dir: folder that stores predictions for each sample\n", "\n", "!(python NeMo/examples/nlp/dialogue/dialogue.py \\\n", " do_training=True \\\n", " model.dataset.dialogues_example_dir='./marco_bart_predictions' \\\n", " model.dataset.data_dir='./ms_marco' \\\n", " model.save_model=True \\\n", " model.dataset.task='ms_marco' \\\n", " model.language_model.pretrained_model_name='facebook/bart-base' \\\n", " trainer.max_epochs=1 \\\n", " model.dataset.debug_mode=False \\\n", " exp_manager.create_wandb_logger=False)" ] 
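}, { "cell_type": "markdown", "metadata": {}, "source": [ "(Optional) The test metrics reported below include `precision`, `recall` and `f1` scores between the generated and the reference fluent answers. As an illustration only, the cell below sketches a bag-of-tokens precision/recall/F1 in the style of SQuAD evaluation. Lowercasing, whitespace tokenization and the helper name `token_prf` are assumptions made for this sketch. NeMo's own metric may normalize text differently, so its numbers will not necessarily match." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from collections import Counter\n", "\n", "\n", "def token_prf(prediction, reference):\n", "    # Hypothetical helper: bag-of-tokens precision, recall and F1 (SQuAD-style sketch).\n", "    # Assumption: text is lowercased and split on whitespace; punctuation is not stripped.\n", "    pred_tokens = prediction.lower().split()\n", "    ref_tokens = reference.lower().split()\n", "    # Count tokens shared by both answers, respecting how often each one occurs.\n", "    overlap = sum((Counter(pred_tokens) & Counter(ref_tokens)).values())\n", "    if overlap == 0:\n", "        return 0.0, 0.0, 0.0\n", "    precision = overlap / len(pred_tokens)\n", "    recall = overlap / len(ref_tokens)\n", "    f1 = 2 * precision * recall / (precision + recall)\n", "    return precision, recall, f1\n", "\n", "\n", "# The extracted answer alone matches the fluent answer with full precision but low recall.\n", "print(token_prf('Onondaga', 'Nine Mile is in Onondaga county.'))\n", "# A prediction identical to the reference scores 1.0 on all three metrics.\n", "print(token_prf('Nine Mile is in Onondaga county.', 'Nine Mile is in Onondaga county.'))" ] 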
}, { "cell_type": "markdown", "metadata": { "id": "UL7ekAOZ2abi" }, "source": [ "**After 1 epoch:**\n", "\n", "Train more epochs for optimal performance\n", "\n", "```\n", "────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\n", " Test metric DataLoader 0\n", "────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\n", " bleu 65.46179962158203\n", " f1 78.24439835896995\n", " precision 81.92473076099847\n", " recall 76.72508929408436\n", " test_accuracy 25.563487607283225\n", " test_loss 0.4419259166606655\n", " test_loss_epoch 0.4420809745788574\n", " test_ppl 1.5557004846779854\n", "────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────\n", "```" ] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [], "name": "Dialogue.ipynb", "provenance": [] }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.7" } }, "nbformat": 4, "nbformat_minor": 0 }