{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "XO1Y8Apj23Uz" }, "source": [ "# Preparing the environment" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "Z9VQRXWGwLQS", "outputId": "00e1b986-7cb0-48e5-f5f1-cae5384237ce" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ " Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m485.4/485.4 kB\u001b[0m \u001b[31m9.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m84.0/84.0 kB\u001b[0m \u001b[31m3.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m116.3/116.3 kB\u001b[0m \u001b[31m6.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m143.5/143.5 kB\u001b[0m \u001b[31m8.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m194.8/194.8 kB\u001b[0m \u001b[31m7.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25h Building wheel for rouge_score (setup.py) ... 
\u001b[?25l\u001b[?25hdone\n" ] } ], "source": [ "# `peft` is imported in the next cell, so install it explicitly instead of relying on the runtime image\n", "%pip install datasets evaluate rouge_score peft --quiet" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "OpAiX--01Etf" }, "outputs": [], "source": [ "import warnings\n", "\n", "import evaluate\n", "import numpy as np\n", "import torch\n", "from datasets import load_dataset\n", "from peft import LoraConfig, PeftModel, get_peft_model\n", "from transformers import (\n", " AutoModelForSeq2SeqLM,\n", " AutoTokenizer,\n", " DataCollatorForSeq2Seq,\n", " Seq2SeqTrainer,\n", " Seq2SeqTrainingArguments,\n", " logging,\n", ")\n", "\n", "CKPT_PATH = \"facebook/bart-large-cnn\" # Path to the fine-tuned base model\n", "DATASET_PATH = \"ccdv/arxiv-summarization\" # Path to the dataset to be used for fine-tuning\n", "HF_REPO_PATH = \"spolivin/bart-arxiv-lora\" # Path to the repo where LoRA adapters are saved\n", "\n", "warnings.filterwarnings(\"ignore\")\n", "logging.set_verbosity_error()" ] }, { "cell_type": "markdown", "metadata": { "id": "KYyo2SOm26dQ" }, "source": [ "# Loading and preprocessing data" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "CDtuyM-WwRwx", "colab": { "base_uri": "https://localhost:8080/", "height": 689, "referenced_widgets": [ "7415870a03b6428094d93967bc9c1bf7", "147c667abe374114babb47b3538ff650", "74d74a270a3040b3bbd7064cc09e1c2d", "153fc063223443a38520be05a5b9d23c", "9491e1872fc04940acca0a0c54bf0107", "3153daf2369847f8921b0501a30ad0f9", "baf687a8771d4e1882f1561b3392be36", "942a70ea07a54764a63e21a225d967c7", "5c1fb52c18524ca99a937c3967804903", "8a5d555401934fa78d47a4070389cdc4", "8e130a2693714e8c9a8a887aa102ca2e", "e1176b3b24a94e7ca8e89925f7a6f2f6", "54a131a63c20490da2de6dfc9950acec", "57b6dfc0003846d19e7d04aaede3284b", "e7c7cd775e7243ddafd47bbeb48b43c0", "6e057db395f642c79133e20fdae6583f", "3a3fda338c5442f18ca772308b134603", "0db9b275dc8142c6bff0f4954556437e", "54225bfa87ff4e94a36573faac0cf094", "5d05b48ad5b44f809b722a6bfa1641d7", 
"44db00544fe24dcaa1ebcc9797c1a128", "1f7f18f3346a40e7b65e6f41854a1db6", "5af00a93f46b47e2b3c7654d414089f9", "1b593e5c1e9d4b9e88587c96ef70e270", "fdce36b4c1e2489ea2536bfab5427ba2", "dd3aefd988f44f7d869b15d5c02a048c", "532298552d0a4bde9af07c3c7477817e", "1c9e6420a3a5454ca46a34c1795b4324", "b9d086ce5d704868976710bd2d0ff6f5", "400d8d393dcd4d8f9636b929e21340d3", "cf82082e746f475893dc78bc2a6797c5", "56b20b27982f4dcaa02a9dc50bd904d7", "f80c90734b3848d1b8de12ac506a0964", "c17ff94a24d0406da47d288a4909c1ca", "483f00b6b6e5481ebb3300a3688ceb94", "f3e79329120a4746ab7f7562a670eced", "c3904dd2df8b4e1a951a017fd4f461b1", "eb00b4846a914476b0c700961ec175bf", "67d208d93aff4fca813709a09121003e", "c0a30097a841484a941fa6e1ed15455e", "a7f512462dc947f98faf971b586174f4", "2fa66890a5204be5ae635b496b567c7b", "ad181c937495464f9c7e6645a206bd16", "bf45e91bfb5242e7afafa5f1307c08b5", "a733cbf17a704ea1ac3728b38305000c", "6736c6a950194779bd53e83a357da672", "89bec3b17b7b432aa3c9d4a3294448c8", "e43ca38b93814b3d85bc10633da6818c", "20617e0dffd94b409a5be5b8232a4d70", "47e5234b4a394ef8b4b5853c82490daf", "272f585e99f043b28ad27e28a37316c2", "a4bb429a6a4c47e783fafdd04ffe2024", "12da388e30ed45909b529261835c60a5", "26c9c6f594934d669a67b583604b45ef", "354854c4bbed4ba0ae4446de3d84dab3", "a30740fb9efc4d7889106aeea3dafc3f", "077cdcc5780f4d4aaefb954128fb7f81", "2f99242950a14d968a7e718e807e5500", "7f5d38cb1b494d13a7b90a8588513d6e", "3dcdf151f0bf46448440b33007516f4f", "8fab3345490e4e339c1c89716004104a", "d1e356e938074cc7a2a4d1d7369122fb", "0d5d321a0bda4629beee9398786b6879", "92e87d9e442f434996662f618b6e1910", "a332bf80f20041349dcc149508824987", "947f01f5dce44808a4495320c2703c97", "0d292670477e49c5a59463bf8ceb8653", "f8ecabd6a3174f268e5684cc5f39ad4f", "4ab73105f6704e5298cd8108bea2dafe", "2c920823612a467487f8c48be703ca65", "ab631562d0dc46bdb1c73cae1945818c", "d67e327188684a7e8926f9f80347b635", "7cfad07d9be643e3a452afe7a8efe75e", "2bc089c8819b4c91b289cb7220600e6a", "7fd04b551b9b486ea95d8c74b95b6bbb", 
"44d786e1aba0420b8cc303e7ff700fb7", "e2d0421c89cc420bb87ec1fdbd0b5429", "4573412cece0487a8f9ae652c13c508c", "f36109440d304ad096a0166abbd55488", "af3aa89b6bbe4782b6af5bfc362ace2d", "a7b5c0dcf84a4f0bab772360b0a0bffa", "be3e1efc43ec481e9d2692ad36d5feb8", "320da12652b443528a8fd1be4b8dc63c", "1fbed2d3ab7b405591de10a8311d2d3e", "c25e9549cfe94bfaaa9c332a17efceb9", "2585b258cce34f5b8f5d748f7ebec0b7", "35fd5bad7c474520967a249488d242d8", "a74007f4d88148919cd73996a18834af", "ae9bc9d26b964dc5a48e3f0b02b9eaff", "f2c3a51fc0f94d7b8365fd828221b8bd", "72f612f166164a129a1deab8b239bc56", "f2f6fca188394021baacda25b372bbcf", "22d68fdfdd074ef5ad7f5e9753f60da9", "49a68c8f7b654d3d9c6e4bf670a4fd09", "d9f268be66c941de80f9fbe7f761895f", "1e849d0c61b64269a545e6fae571a3c1", "5ce7bf03c9084fda988a874499b60ced", "da6767ce458a42a0ad57e72c337615eb", "b4ee9a2fffbb43bbb516d51c76ea0024", "910f32217ca148569e2275bd6e80ee42", "554113221b4a41739bb29c27489cedf8", "e89a2cf7de024720ad55a4c303e63ae1", "f506c270936649afa1e0740aea09a68a", "fea185d942104d419266490de49e0e6b", "4d71808e2482405cb3a48cdaf9d0c3ce", "5c1ac29e2a234dfc86ef37e26db5d3e5", "005bef960a404cf0bc618983c8465ace", "fa7b9fa19aee40e5b74237ca98d6ac0c", "a71f152a18ad4fafb09ecb5809c05cf9", "ed11cbcb10634d2b8e3a0dbec3175ca6", "65d403eaa1a64f99924ea3feb2f23f3b", "648ffc629ad247a09cad739a79ae9940", "574b5ba5efea4697a0177c1cee243b14", "366195973d464b03ad4f546a972a70db", "730a0786e1074d7ca12b27f7ec95825d", "b3f7fe2042064e46bf004c0724dbf69d", "c77bcf15a4df402d9df71fe83299d01d", "d480ae6f75e3417daa699a463d2f13f9", "7c1962422ab24243b01ccb5d7ce2e837", "deead65756834d49a6f4e1efb33e4b96", "60c485bcb48b48b6a7ed34cbcd0f1814", "cd3e93916ad749f59a54ff3e806b4c25", "826ec6bd3fa6448285268f500db7bb53", "c92e1a59f4894849912fc2c71dc7ca6e", "92045e862cb64dda87b4ca43a481baf5", "5b4ca483822e4d4e9660dceba05a1149", "84b86e9c608948698329c68359f659ba", "d0904dd9b54d43128325964b0d365780", "e2656b0a48e94db183a7504f98a94434", "2973221736f742f1a6784f1ccfe107b9", 
"142dcec0a67b47388de223f0b4e4ff64", "f6582a63e4ea45f0ad4ff4ea2b58f2c8", "a3a856a330dd49edb24dbd19d088df85", "6a719486b10940008c6ecddd2ca5b1d2", "048af3eeeb4346e195980c122832b82f", "5a58b50548d94103b29053fbeca3e6e7", "3c7878713b0f4536af9cffef753c1b70", "b95126cdb0ad49738ba1d20122b11321", "5c53e90676094dd2a24aa5e3b28ce85c", "43146d4df1d34540a94b7116dec911ed", "ba9da8ec34c640539939af29d2d7f834", "006b174751954fd9853120663382f829", "e42f291e768a4ade9f80c3c280e1b867", "755e9f75c4d442a39e69f3fb58d45d37", "427411daa42b4eac818dd80c4cf7bbe7", "a91462a79b8b4b70869a8747a81856f9", "4474163f967945959d8cf6753094d0e5", "7f56d8f3257447e1826b8794228ec1fa", "fe7e2661aea54d039ae361baf48f08f0", "74667407ce13462a85711f13a4a54e3c", "b5a2f043780b42e4bae9668488203b4b", "3d47cd1e40e84513b550b174f77e4169", "65f3eabcf4074c0889e9fb4084248b52", "71682cb9d10f4ad2871b61b6807db98f", "da6184178b924fdd931dacbff898733d", "9caab83b53f74b6595ec29021f0b7d34", "495ab747c762446bba8789b30a0769da", "f52a1139cea8486c8620a672f6f83dcc", "3dc5907ea916474fa96b1b4123723f9f", "275ac33a934c4cd68b671b42580b685b", "80436fad8c0c46ca9a3e8f7ce1dff1ab", "e42a6fd3679e488dbb9b28cb37ae6fae", "9ab226df46b341578b52903ac6a6b4e7", "2abd8f1e9fba40c6a71a855411d4f4a6", "f33507c075b0427fb6d16ca5f3739647", "99e5689e3f0045329e660333a4b80421", "2040589fc1aa4f2cad742269d239d040", "4f94f6d5cb65415f88564845a18fe397", "d8c57676fa3c410f94cea59a88de18bd", "4d90782c33a049bf8367e549548817d9", "de07e6f60492476daad70c6aa50fc458", "d11862546c9c461b8fd8a4d93b895b27", "e257a36279c840a2ad79dd6a02809342", "b730f21f20b44a26b70a6118cc96efa7", "fe37c81afc7348eabce470662aad40c1", "c0a7cf499c9e4399a11bec625d5d8a65", "44a832985e144790842aa2316001042a", "30a1dfc6ea8c48fb8a44d955afff8bd9", "72576e1f134443f5a93c7c752884053c", "b3ab5c6e2dd5411f9c650fd77812b640", "226759542a2e489189c3828e6cd62730", "96c93618176445f4bc390894ae61a7f5", "7186aa017a2b416c99e5d27b11fe4e5c", "4bcdd86c3cf7469a84e3af63859e65b9", "2d07d523e2e44828ae6357afc1e67f60", 
"1371cfd81c104dc28cdff8b7bfaef6dd", "20dc7421a9594766904b35dcc967a65a", "1e7ed9b5aaf64310a20e99492abe0e31", "4ed5efbfd59044cda48519005fc03aa1", "df936633e7cb402682cdefd516ec261e", "b6a85f7e4934455f9c1a2d6433f638c3", "6b4c354cc4794491986f5ff73d900e0a", "5a558e5d23b04e51a14e2018a9f0cd1a", "aa4dd00a74a94fa5b10f3f327c490154", "b2c53cc756064ec79be2d006e5ed9457", "553fd4038de644a99178c37a976b65e2", "a5b4001c73924fad95956963ef0a96ae", "d2a5b4cf16f84d32ae9660dda0acd17d", "51876a46b62c4d3abd08849c39ef3a2c", "7c78244b6fc24f259d5c80e978d75eba", "853e6969e04a4a62969c15cc98b0dcaa", "d0a1318715a4493385c6d0908b1afc87", "db22f4bcc5c44402bebc10ecfc5d4c0d", "c2835defa4a6469e89ddea637b44984f", "1f7382034b924d3a995a954958d52c82", "6cb7e2d6a00d488a876835fb2edfab03", "26030f3fbf44489196e65acf3f34aeab", "22ecd5dadc0f4dabac8d6c70e9d61159", "fb05c3e7c86647b3b0df679cbcda1b97", "2e133611f2a7423e92454a4bd55f4539", "30fce49326ee49aa94994f36250da32c", "66826178a3164366939cdc61f5e8c4ec", "502294de61994a21ada4d958f45db0e8", "f0139c431a414c03b393a33c4d03e29f", "c28f1de025794345b2e5ea0a9c9c2943", "ffdb88c051cf4e05958eb2e7ab60f757", "1a2a60701c92402193ad2001e1230eb3", "a7205db7641c49a69ab7f01d9e88a4da", "6d6cdedf745544649e6e8c12c70f7c89", "a5af12cadb5142609cf9e6777b8affdc", "ceb36795fb824c8f92b6c09dbb0b870c", "b6806870954b42e6be8cdc3615f927b6", "bf382ff0cc2a4211b9fbd123b885bd94", "e716868356d84277a6cbd025f21e04ce", "61b438186d8f45b5a1a1bbde1960e663", "f781f48a608d46cfa3bd09d86b524eec", "f3c9b5b8a013481bb1ce0dbc9ff3ffb3", "6a66304d66004e649c4e79e91ce93bc7", "9ec6b51b2c614c0c81c2d50247d32e7e", "9760b1b673b14f5a8c6cf3cf0feccea3", "0cada0bf267142228a4c760aff6f34b6" ] }, "outputId": "3ebf261d-9a6a-425f-a0a9-dc8d567019e4" }, "outputs": [ { "output_type": "display_data", "data": { "text/plain": [ "README.md: 0%| | 0.00/3.96k [00:00" ], "text/html": [ "\n", "
\n", " \n", " \n", " [936/936 1:28:11, Epoch 5/6]\n", "
\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
EpochTraining LossValidation LossRouge1Rouge2RougelRougelsumGen Len
1No log2.70142440.97510013.08580022.72180034.009600138.116700
2No log2.63396341.98540013.91700023.44030035.034800138.423300
3No log2.60816742.38030014.12180023.63330035.320200136.483300
42.8575002.59326242.10680014.13760023.75810035.028400137.106700
52.8575002.58840342.29790013.95030023.66250035.186700137.140000

" ] }, "metadata": {} }, { "output_type": "execute_result", "data": { "text/plain": [ "TrainOutput(global_step=936, training_loss=2.759424845377604, metrics={'train_runtime': 5295.7373, 'train_samples_per_second': 5.665, 'train_steps_per_second': 0.177, 'total_flos': 6.486527958633677e+16, 'train_loss': 2.759424845377604, 'epoch': 5.9664})" ] }, "metadata": {}, "execution_count": 21 } ], "source": [ "torch.cuda.empty_cache()\n", "\n", "# Effective batch per optimizer step: 8 per device x 4 gradient-accumulation steps = 32 (per device)\n", "training_args = Seq2SeqTrainingArguments(\n", " output_dir=HF_REPO_PATH.split(\"/\")[-1],\n", " eval_strategy=\"epoch\", # formerly `evaluation_strategy`, which recent transformers releases no longer accept\n", " save_strategy=\"epoch\", # must match eval_strategy for load_best_model_at_end\n", " logging_strategy=\"epoch\", # report train loss every epoch; default logging_steps=500 left early epochs as \"No log\"\n", " learning_rate=2e-4,\n", " warmup_ratio=0.1,\n", " per_device_train_batch_size=8,\n", " per_device_eval_batch_size=8,\n", " gradient_accumulation_steps=4,\n", " lr_scheduler_type=\"cosine\",\n", " weight_decay=0.01,\n", " num_train_epochs=6,\n", " predict_with_generate=True, # generate summaries during eval so compute_metrics can score ROUGE\n", " load_best_model_at_end=True,\n", " metric_for_best_model=\"rouge2\",\n", " greater_is_better=True,\n", " fp16=True,\n", " push_to_hub=False,\n", " report_to=\"none\",\n", " disable_tqdm=False,\n", ")\n", "\n", "trainer = Seq2SeqTrainer(\n", " model=model,\n", " args=training_args,\n", " train_dataset=tokenized_arxiv_train,\n", " eval_dataset=tokenized_arxiv_valid,\n", " tokenizer=tokenizer,\n", " data_collator=data_collator,\n", " compute_metrics=compute_metrics,\n", ")\n", "\n", "trainer.train()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "47u4FyoX71k2", "colab": { "base_uri": "https://localhost:8080/", "height": 101, "referenced_widgets": [ "4a4aa54004fc4b5a888b22b10e928573", "9904a43f83254a8f93ad11ba4711980d", "3c62ae3d2dc2428eac88a1ae5215160a", "7ea9233fd4b145c088e4145033da08a6", "896e9b913fa74baa81f29510eabfb886", "242198d60272429da927efca1703f522", "54d7a5a208b544e88d5d582948e8f541", "18efd5066fda422eb6492fedcbc8b756", "3d8b2fd48a3e4a9baf6a46bc91533b66", "da3f16a81bce49ef9671492d5c7217c3", "fcd966de3dc84aba8cd8b25d691ca585" ] }, "outputId": 
"827433cb-14d3-4f2d-f798-d74cee6277bf" }, "outputs": [ { "output_type": "display_data", "data": { "text/plain": [ "adapter_model.safetensors: 0%| | 0.00/4.74M [00:00