diff --git "a/arxiv_summarization_lora.ipynb" "b/arxiv_summarization_lora.ipynb" new file mode 100644--- /dev/null +++ "b/arxiv_summarization_lora.ipynb" @@ -0,0 +1,12792 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "XO1Y8Apj23Uz" + }, + "source": [ + "# Preparing the environment" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "Z9VQRXWGwLQS", + "outputId": "00e1b986-7cb0-48e5-f5f1-cae5384237ce" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + " Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m485.4/485.4 kB\u001b[0m \u001b[31m9.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m84.0/84.0 kB\u001b[0m \u001b[31m3.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m116.3/116.3 kB\u001b[0m \u001b[31m6.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m143.5/143.5 kB\u001b[0m \u001b[31m8.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m194.8/194.8 kB\u001b[0m \u001b[31m7.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25h Building wheel for rouge_score (setup.py) ... \u001b[?25l\u001b[?25hdone\n" + ] + } + ], + "source": [ + "!pip install datasets evaluate rouge_score --quiet" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "OpAiX--01Etf" + }, + "outputs": [], + "source": [ + "import warnings\n", + "\n", + "import evaluate\n", + "import numpy as np\n", + "import torch\n", + "from datasets import load_dataset\n", + "from peft import LoraConfig, PeftModel, get_peft_model\n", + "from transformers import (\n", + " AutoModelForSeq2SeqLM,\n", + " AutoTokenizer,\n", + " DataCollatorForSeq2Seq,\n", + " Seq2SeqTrainer,\n", + " Seq2SeqTrainingArguments,\n", + " logging,\n", + ")\n", + "\n", + "CKPT_PATH = \"facebook/bart-large-cnn\" # Path to the fine-tuned base model\n", + "DATASET_PATH = \"ccdv/arxiv-summarization\" # Path to the dataset to be used for fine-tuning\n", + "HF_REPO_PATH = \"spolivin/bart-arxiv-lora\" # Path to the repo where LoRA adapters are saved\n", + "\n", + "warnings.filterwarnings(\"ignore\")\n", + "logging.set_verbosity_error()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "KYyo2SOm26dQ" + }, + "source": [ + "# Loading and preprocessing data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "CDtuyM-WwRwx", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 689, + "referenced_widgets": [ + "7415870a03b6428094d93967bc9c1bf7", + "147c667abe374114babb47b3538ff650", + "74d74a270a3040b3bbd7064cc09e1c2d", + "153fc063223443a38520be05a5b9d23c", + "9491e1872fc04940acca0a0c54bf0107", + "3153daf2369847f8921b0501a30ad0f9", + "baf687a8771d4e1882f1561b3392be36", + "942a70ea07a54764a63e21a225d967c7", + "5c1fb52c18524ca99a937c3967804903", + "8a5d555401934fa78d47a4070389cdc4", + "8e130a2693714e8c9a8a887aa102ca2e", + "e1176b3b24a94e7ca8e89925f7a6f2f6", + "54a131a63c20490da2de6dfc9950acec", + "57b6dfc0003846d19e7d04aaede3284b", + "e7c7cd775e7243ddafd47bbeb48b43c0", + "6e057db395f642c79133e20fdae6583f", + "3a3fda338c5442f18ca772308b134603", + "0db9b275dc8142c6bff0f4954556437e", + "54225bfa87ff4e94a36573faac0cf094", + "5d05b48ad5b44f809b722a6bfa1641d7", + "44db00544fe24dcaa1ebcc9797c1a128", + "1f7f18f3346a40e7b65e6f41854a1db6", + "5af00a93f46b47e2b3c7654d414089f9", + "1b593e5c1e9d4b9e88587c96ef70e270", + "fdce36b4c1e2489ea2536bfab5427ba2", + "dd3aefd988f44f7d869b15d5c02a048c", + "532298552d0a4bde9af07c3c7477817e", + "1c9e6420a3a5454ca46a34c1795b4324", + "b9d086ce5d704868976710bd2d0ff6f5", + "400d8d393dcd4d8f9636b929e21340d3", + "cf82082e746f475893dc78bc2a6797c5", + "56b20b27982f4dcaa02a9dc50bd904d7", + "f80c90734b3848d1b8de12ac506a0964", + "c17ff94a24d0406da47d288a4909c1ca", + "483f00b6b6e5481ebb3300a3688ceb94", + "f3e79329120a4746ab7f7562a670eced", + "c3904dd2df8b4e1a951a017fd4f461b1", + "eb00b4846a914476b0c700961ec175bf", + "67d208d93aff4fca813709a09121003e", + "c0a30097a841484a941fa6e1ed15455e", + "a7f512462dc947f98faf971b586174f4", + "2fa66890a5204be5ae635b496b567c7b", + "ad181c937495464f9c7e6645a206bd16", + "bf45e91bfb5242e7afafa5f1307c08b5", + "a733cbf17a704ea1ac3728b38305000c", + "6736c6a950194779bd53e83a357da672", + "89bec3b17b7b432aa3c9d4a3294448c8", + "e43ca38b93814b3d85bc10633da6818c", + "20617e0dffd94b409a5be5b8232a4d70", + "47e5234b4a394ef8b4b5853c82490daf", + "272f585e99f043b28ad27e28a37316c2", + "a4bb429a6a4c47e783fafdd04ffe2024", + "12da388e30ed45909b529261835c60a5", + "26c9c6f594934d669a67b583604b45ef", + "354854c4bbed4ba0ae4446de3d84dab3", + "a30740fb9efc4d7889106aeea3dafc3f", + "077cdcc5780f4d4aaefb954128fb7f81", + "2f99242950a14d968a7e718e807e5500", + "7f5d38cb1b494d13a7b90a8588513d6e", + "3dcdf151f0bf46448440b33007516f4f", + "8fab3345490e4e339c1c89716004104a", + "d1e356e938074cc7a2a4d1d7369122fb", + "0d5d321a0bda4629beee9398786b6879", + "92e87d9e442f434996662f618b6e1910", + "a332bf80f20041349dcc149508824987", + "947f01f5dce44808a4495320c2703c97", + "0d292670477e49c5a59463bf8ceb8653", + "f8ecabd6a3174f268e5684cc5f39ad4f", + "4ab73105f6704e5298cd8108bea2dafe", + "2c920823612a467487f8c48be703ca65", + "ab631562d0dc46bdb1c73cae1945818c", + "d67e327188684a7e8926f9f80347b635", + "7cfad07d9be643e3a452afe7a8efe75e", + "2bc089c8819b4c91b289cb7220600e6a", + "7fd04b551b9b486ea95d8c74b95b6bbb", + "44d786e1aba0420b8cc303e7ff700fb7", + "e2d0421c89cc420bb87ec1fdbd0b5429", + "4573412cece0487a8f9ae652c13c508c", + "f36109440d304ad096a0166abbd55488", + "af3aa89b6bbe4782b6af5bfc362ace2d", + "a7b5c0dcf84a4f0bab772360b0a0bffa", + "be3e1efc43ec481e9d2692ad36d5feb8", + "320da12652b443528a8fd1be4b8dc63c", + "1fbed2d3ab7b405591de10a8311d2d3e", + "c25e9549cfe94bfaaa9c332a17efceb9", + "2585b258cce34f5b8f5d748f7ebec0b7", + "35fd5bad7c474520967a249488d242d8", + "a74007f4d88148919cd73996a18834af", + "ae9bc9d26b964dc5a48e3f0b02b9eaff", + "f2c3a51fc0f94d7b8365fd828221b8bd", + "72f612f166164a129a1deab8b239bc56", + "f2f6fca188394021baacda25b372bbcf", + "22d68fdfdd074ef5ad7f5e9753f60da9", + "49a68c8f7b654d3d9c6e4bf670a4fd09", + "d9f268be66c941de80f9fbe7f761895f", + "1e849d0c61b64269a545e6fae571a3c1", + "5ce7bf03c9084fda988a874499b60ced", + "da6767ce458a42a0ad57e72c337615eb", + "b4ee9a2fffbb43bbb516d51c76ea0024", + "910f32217ca148569e2275bd6e80ee42", + "554113221b4a41739bb29c27489cedf8", + "e89a2cf7de024720ad55a4c303e63ae1", + "f506c270936649afa1e0740aea09a68a", + "fea185d942104d419266490de49e0e6b", + "4d71808e2482405cb3a48cdaf9d0c3ce", + "5c1ac29e2a234dfc86ef37e26db5d3e5", + "005bef960a404cf0bc618983c8465ace", + "fa7b9fa19aee40e5b74237ca98d6ac0c", + "a71f152a18ad4fafb09ecb5809c05cf9", + "ed11cbcb10634d2b8e3a0dbec3175ca6", + "65d403eaa1a64f99924ea3feb2f23f3b", + "648ffc629ad247a09cad739a79ae9940", + "574b5ba5efea4697a0177c1cee243b14", + "366195973d464b03ad4f546a972a70db", + "730a0786e1074d7ca12b27f7ec95825d", + "b3f7fe2042064e46bf004c0724dbf69d", + "c77bcf15a4df402d9df71fe83299d01d", + "d480ae6f75e3417daa699a463d2f13f9", + "7c1962422ab24243b01ccb5d7ce2e837", + "deead65756834d49a6f4e1efb33e4b96", + "60c485bcb48b48b6a7ed34cbcd0f1814", + "cd3e93916ad749f59a54ff3e806b4c25", + "826ec6bd3fa6448285268f500db7bb53", + "c92e1a59f4894849912fc2c71dc7ca6e", + "92045e862cb64dda87b4ca43a481baf5", + "5b4ca483822e4d4e9660dceba05a1149", + "84b86e9c608948698329c68359f659ba", + "d0904dd9b54d43128325964b0d365780", + "e2656b0a48e94db183a7504f98a94434", + "2973221736f742f1a6784f1ccfe107b9", + "142dcec0a67b47388de223f0b4e4ff64", + "f6582a63e4ea45f0ad4ff4ea2b58f2c8", + "a3a856a330dd49edb24dbd19d088df85", + "6a719486b10940008c6ecddd2ca5b1d2", + "048af3eeeb4346e195980c122832b82f", + "5a58b50548d94103b29053fbeca3e6e7", + "3c7878713b0f4536af9cffef753c1b70", + "b95126cdb0ad49738ba1d20122b11321", + "5c53e90676094dd2a24aa5e3b28ce85c", + "43146d4df1d34540a94b7116dec911ed", + "ba9da8ec34c640539939af29d2d7f834", + "006b174751954fd9853120663382f829", + "e42f291e768a4ade9f80c3c280e1b867", + "755e9f75c4d442a39e69f3fb58d45d37", + "427411daa42b4eac818dd80c4cf7bbe7", + "a91462a79b8b4b70869a8747a81856f9", + "4474163f967945959d8cf6753094d0e5", + "7f56d8f3257447e1826b8794228ec1fa", + "fe7e2661aea54d039ae361baf48f08f0", + "74667407ce13462a85711f13a4a54e3c", + "b5a2f043780b42e4bae9668488203b4b", + "3d47cd1e40e84513b550b174f77e4169", + "65f3eabcf4074c0889e9fb4084248b52", + "71682cb9d10f4ad2871b61b6807db98f", + "da6184178b924fdd931dacbff898733d", + "9caab83b53f74b6595ec29021f0b7d34", + "495ab747c762446bba8789b30a0769da", + "f52a1139cea8486c8620a672f6f83dcc", + "3dc5907ea916474fa96b1b4123723f9f", + "275ac33a934c4cd68b671b42580b685b", + "80436fad8c0c46ca9a3e8f7ce1dff1ab", + "e42a6fd3679e488dbb9b28cb37ae6fae", + "9ab226df46b341578b52903ac6a6b4e7", + "2abd8f1e9fba40c6a71a855411d4f4a6", + "f33507c075b0427fb6d16ca5f3739647", + "99e5689e3f0045329e660333a4b80421", + "2040589fc1aa4f2cad742269d239d040", + "4f94f6d5cb65415f88564845a18fe397", + "d8c57676fa3c410f94cea59a88de18bd", + "4d90782c33a049bf8367e549548817d9", + "de07e6f60492476daad70c6aa50fc458", + "d11862546c9c461b8fd8a4d93b895b27", + "e257a36279c840a2ad79dd6a02809342", + "b730f21f20b44a26b70a6118cc96efa7", + "fe37c81afc7348eabce470662aad40c1", + "c0a7cf499c9e4399a11bec625d5d8a65", + "44a832985e144790842aa2316001042a", + "30a1dfc6ea8c48fb8a44d955afff8bd9", + "72576e1f134443f5a93c7c752884053c", + "b3ab5c6e2dd5411f9c650fd77812b640", + "226759542a2e489189c3828e6cd62730", + "96c93618176445f4bc390894ae61a7f5", + "7186aa017a2b416c99e5d27b11fe4e5c", + "4bcdd86c3cf7469a84e3af63859e65b9", + "2d07d523e2e44828ae6357afc1e67f60", + "1371cfd81c104dc28cdff8b7bfaef6dd", + "20dc7421a9594766904b35dcc967a65a", + "1e7ed9b5aaf64310a20e99492abe0e31", + "4ed5efbfd59044cda48519005fc03aa1", + "df936633e7cb402682cdefd516ec261e", + "b6a85f7e4934455f9c1a2d6433f638c3", + "6b4c354cc4794491986f5ff73d900e0a", + "5a558e5d23b04e51a14e2018a9f0cd1a", + "aa4dd00a74a94fa5b10f3f327c490154", + "b2c53cc756064ec79be2d006e5ed9457", + "553fd4038de644a99178c37a976b65e2", + "a5b4001c73924fad95956963ef0a96ae", + "d2a5b4cf16f84d32ae9660dda0acd17d", + "51876a46b62c4d3abd08849c39ef3a2c", + "7c78244b6fc24f259d5c80e978d75eba", + "853e6969e04a4a62969c15cc98b0dcaa", + "d0a1318715a4493385c6d0908b1afc87", + "db22f4bcc5c44402bebc10ecfc5d4c0d", + "c2835defa4a6469e89ddea637b44984f", + "1f7382034b924d3a995a954958d52c82", + "6cb7e2d6a00d488a876835fb2edfab03", + "26030f3fbf44489196e65acf3f34aeab", + "22ecd5dadc0f4dabac8d6c70e9d61159", + "fb05c3e7c86647b3b0df679cbcda1b97", + "2e133611f2a7423e92454a4bd55f4539", + "30fce49326ee49aa94994f36250da32c", + "66826178a3164366939cdc61f5e8c4ec", + "502294de61994a21ada4d958f45db0e8", + "f0139c431a414c03b393a33c4d03e29f", + "c28f1de025794345b2e5ea0a9c9c2943", + "ffdb88c051cf4e05958eb2e7ab60f757", + "1a2a60701c92402193ad2001e1230eb3", + "a7205db7641c49a69ab7f01d9e88a4da", + "6d6cdedf745544649e6e8c12c70f7c89", + "a5af12cadb5142609cf9e6777b8affdc", + "ceb36795fb824c8f92b6c09dbb0b870c", + "b6806870954b42e6be8cdc3615f927b6", + "bf382ff0cc2a4211b9fbd123b885bd94", + "e716868356d84277a6cbd025f21e04ce", + "61b438186d8f45b5a1a1bbde1960e663", + "f781f48a608d46cfa3bd09d86b524eec", + "f3c9b5b8a013481bb1ce0dbc9ff3ffb3", + "6a66304d66004e649c4e79e91ce93bc7", + "9ec6b51b2c614c0c81c2d50247d32e7e", + "9760b1b673b14f5a8c6cf3cf0feccea3", + "0cada0bf267142228a4c760aff6f34b6" + ] + }, + "outputId": "3ebf261d-9a6a-425f-a0a9-dc8d567019e4" + }, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [ + "README.md: 0%| | 0.00/3.96k [00:00" + ], + "text/html": [ + "\n", + "
\n", + " \n", + " \n", + " [936/936 1:28:11, Epoch 5/6]\n", + "
\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
EpochTraining LossValidation LossRouge1Rouge2RougelRougelsumGen Len
1No log2.70142440.97510013.08580022.72180034.009600138.116700
2No log2.63396341.98540013.91700023.44030035.034800138.423300
3No log2.60816742.38030014.12180023.63330035.320200136.483300
42.8575002.59326242.10680014.13760023.75810035.028400137.106700
52.8575002.58840342.29790013.95030023.66250035.186700137.140000

" + ] + }, + "metadata": {} + }, + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "TrainOutput(global_step=936, training_loss=2.759424845377604, metrics={'train_runtime': 5295.7373, 'train_samples_per_second': 5.665, 'train_steps_per_second': 0.177, 'total_flos': 6.486527958633677e+16, 'train_loss': 2.759424845377604, 'epoch': 5.9664})" + ] + }, + "metadata": {}, + "execution_count": 21 + } + ], + "source": [ + "torch.cuda.empty_cache()\n", + "\n", + "training_args = Seq2SeqTrainingArguments(\n", + " output_dir=HF_REPO_PATH.split(\"/\")[-1],\n", + " evaluation_strategy=\"epoch\",\n", + " save_strategy=\"epoch\",\n", + " learning_rate=2e-4,\n", + " warmup_ratio=0.1,\n", + " per_device_train_batch_size=8,\n", + " per_device_eval_batch_size=8,\n", + " gradient_accumulation_steps=4,\n", + " lr_scheduler_type=\"cosine\",\n", + " weight_decay=0.01,\n", + " num_train_epochs=6,\n", + " predict_with_generate=True,\n", + " load_best_model_at_end=True,\n", + " metric_for_best_model=\"rouge2\",\n", + " greater_is_better=True,\n", + " fp16=True,\n", + " push_to_hub=False,\n", + " report_to=\"none\",\n", + " disable_tqdm=False,\n", + ")\n", + "\n", + "trainer = Seq2SeqTrainer(\n", + " model=model,\n", + " args=training_args,\n", + " train_dataset=tokenized_arxiv_train,\n", + " eval_dataset=tokenized_arxiv_valid,\n", + " tokenizer=tokenizer,\n", + " data_collator=data_collator,\n", + " compute_metrics=compute_metrics,\n", + ")\n", + "\n", + "trainer.train()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "47u4FyoX71k2", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 101, + "referenced_widgets": [ + "4a4aa54004fc4b5a888b22b10e928573", + "9904a43f83254a8f93ad11ba4711980d", + "3c62ae3d2dc2428eac88a1ae5215160a", + "7ea9233fd4b145c088e4145033da08a6", + "896e9b913fa74baa81f29510eabfb886", + "242198d60272429da927efca1703f522", + "54d7a5a208b544e88d5d582948e8f541", + "18efd5066fda422eb6492fedcbc8b756", + "3d8b2fd48a3e4a9baf6a46bc91533b66", + "da3f16a81bce49ef9671492d5c7217c3", + "fcd966de3dc84aba8cd8b25d691ca585" + ] + }, + "outputId": "827433cb-14d3-4f2d-f798-d74cee6277bf" + }, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [ + "adapter_model.safetensors: 0%| | 0.00/4.74M [00:00