Matthew Hollings commited on
Commit
3d96507
·
1 Parent(s): 4565d47

Use my fine-tuned model from huggingface

Browse files
Files changed (3) hide show
  1. .gitignore +2 -1
  2. app.py +1 -1
  3. fine-tuning-for-casual-language-model.ipynb +119 -151
.gitignore CHANGED
@@ -2,4 +2,5 @@ __pycache__
2
  flagged/
3
  gutenberg-dammit-files-v002.zip
4
  tmp_trainer
5
- *.gz
 
 
2
  flagged/
3
  gutenberg-dammit-files-v002.zip
4
  tmp_trainer
5
+ *.gz
6
+ gpt2-poetry-model
app.py CHANGED
@@ -3,7 +3,7 @@ import gradio as gr
3
  from transformers import pipeline
4
 
5
  # Set up the generatove model transformer pipeline
6
- generator = pipeline("text-generation", model="tmp_trainer")
7
 
8
  # A sequence of lines both those typed in and the line so far
9
  # when save is clicked the txt file is downloaded
 
3
  from transformers import pipeline
4
 
5
  # Set up the generatove model transformer pipeline
6
+ generator = pipeline("text-generation", model="matthh/gpt2-poetry-model")
7
 
8
  # A sequence of lines both those typed in and the line so far
9
  # when save is clicked the txt file is downloaded
fine-tuning-for-casual-language-model.ipynb CHANGED
@@ -2,7 +2,7 @@
2
  "cells": [
3
  {
4
  "cell_type": "code",
5
- "execution_count": null,
6
  "metadata": {},
7
  "outputs": [],
8
  "source": [
@@ -11,7 +11,7 @@
11
  },
12
  {
13
  "cell_type": "code",
14
- "execution_count": 43,
15
  "metadata": {},
16
  "outputs": [],
17
  "source": [
@@ -43,29 +43,16 @@
43
  },
44
  {
45
  "cell_type": "code",
46
- "execution_count": 4,
47
  "metadata": {},
48
- "outputs": [
49
- {
50
- "ename": "ImportError",
51
- "evalue": "This example requires a source install from HuggingFace Transformers (see `https://huggingface.co/transformers/installation.html#installing-from-source`), but the version found is 4.11.3.\nCheck out https://huggingface.co/transformers/examples.html for the examples corresponding to other versions of HuggingFace Transformers.",
52
- "output_type": "error",
53
- "traceback": [
54
- "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
55
- "\u001b[0;31mImportError\u001b[0m Traceback (most recent call last)",
56
- "Cell \u001b[0;32mIn [4], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mcheck_min_version\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43m4.23.0.dev0\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n",
57
- "File \u001b[0;32m/opt/homebrew/Caskroom/miniforge/base/envs/augmented_poetry/lib/python3.8/site-packages/transformers/utils/__init__.py:32\u001b[0m, in \u001b[0;36mcheck_min_version\u001b[0;34m(min_version)\u001b[0m\n\u001b[1;32m 30\u001b[0m error_message \u001b[39m=\u001b[39m \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mThis example requires a minimum version of \u001b[39m\u001b[39m{\u001b[39;00mmin_version\u001b[39m}\u001b[39;00m\u001b[39m,\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 31\u001b[0m error_message \u001b[39m+\u001b[39m\u001b[39m=\u001b[39m \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39m but the version found is \u001b[39m\u001b[39m{\u001b[39;00m__version__\u001b[39m}\u001b[39;00m\u001b[39m.\u001b[39m\u001b[39m\\n\u001b[39;00m\u001b[39m\"\u001b[39m\n\u001b[0;32m---> 32\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mImportError\u001b[39;00m(\n\u001b[1;32m 33\u001b[0m error_message\n\u001b[1;32m 34\u001b[0m \u001b[39m+\u001b[39m (\n\u001b[1;32m 35\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mCheck out https://huggingface.co/transformers/examples.html for the examples corresponding to other \u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 36\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mversions of HuggingFace Transformers.\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 37\u001b[0m )\n\u001b[1;32m 38\u001b[0m )\n",
58
- "\u001b[0;31mImportError\u001b[0m: This example requires a source install from HuggingFace Transformers (see `https://huggingface.co/transformers/installation.html#installing-from-source`), but the version found is 4.11.3.\nCheck out https://huggingface.co/transformers/examples.html for the examples corresponding to other versions of HuggingFace Transformers."
59
- ]
60
- }
61
- ],
62
  "source": [
63
  "# check_min_version(\"4.23.0.dev0\")"
64
  ]
65
  },
66
  {
67
  "cell_type": "code",
68
- "execution_count": 9,
69
  "metadata": {},
70
  "outputs": [],
71
  "source": [
@@ -74,7 +61,7 @@
74
  },
75
  {
76
  "cell_type": "code",
77
- "execution_count": 5,
78
  "metadata": {},
79
  "outputs": [],
80
  "source": [
@@ -90,90 +77,23 @@
90
  },
91
  {
92
  "cell_type": "code",
93
- "execution_count": 10,
94
  "metadata": {},
95
  "outputs": [
96
  {
97
  "name": "stderr",
98
  "output_type": "stream",
99
  "text": [
100
- "Using custom data configuration merve--poetry-ca9a13ef5858cc3a\n"
101
- ]
102
- },
103
- {
104
- "name": "stdout",
105
- "output_type": "stream",
106
- "text": [
107
- "Downloading and preparing dataset csv/merve--poetry to /Users/matth/.cache/huggingface/datasets/merve___csv/merve--poetry-ca9a13ef5858cc3a/0.0.0/652c3096f041ee27b04d2232d41f10547a8fecda3e284a79a0ec4053c916ef7a...\n"
108
  ]
109
  },
110
  {
111
  "data": {
112
  "application/vnd.jupyter.widget-view+json": {
113
- "model_id": "ed56ee6b324647798b19ac7bf5accc40",
114
- "version_major": 2,
115
- "version_minor": 0
116
- },
117
- "text/plain": [
118
- "Downloading data files: 0%| | 0/1 [00:00<?, ?it/s]"
119
- ]
120
- },
121
- "metadata": {},
122
- "output_type": "display_data"
123
- },
124
- {
125
- "data": {
126
- "application/vnd.jupyter.widget-view+json": {
127
- "model_id": "32c10441ff20404cb153f6b27f16a829",
128
- "version_major": 2,
129
- "version_minor": 0
130
- },
131
- "text/plain": [
132
- "Downloading data: 0%| | 0.00/606k [00:00<?, ?B/s]"
133
- ]
134
- },
135
- "metadata": {},
136
- "output_type": "display_data"
137
- },
138
- {
139
- "data": {
140
- "application/vnd.jupyter.widget-view+json": {
141
- "model_id": "7ca47bc06937463e91d3948d7703ac64",
142
- "version_major": 2,
143
- "version_minor": 0
144
- },
145
- "text/plain": [
146
- "Extracting data files: 0%| | 0/1 [00:00<?, ?it/s]"
147
- ]
148
- },
149
- "metadata": {},
150
- "output_type": "display_data"
151
- },
152
- {
153
- "data": {
154
- "application/vnd.jupyter.widget-view+json": {
155
- "model_id": "1631dbdc53d04b14a8a7733883bbd1cc",
156
- "version_major": 2,
157
- "version_minor": 0
158
- },
159
- "text/plain": [
160
- "0 tables [00:00, ? tables/s]"
161
- ]
162
- },
163
- "metadata": {},
164
- "output_type": "display_data"
165
- },
166
- {
167
- "name": "stdout",
168
- "output_type": "stream",
169
- "text": [
170
- "Dataset csv downloaded and prepared to /Users/matth/.cache/huggingface/datasets/merve___csv/merve--poetry-ca9a13ef5858cc3a/0.0.0/652c3096f041ee27b04d2232d41f10547a8fecda3e284a79a0ec4053c916ef7a. Subsequent calls will reuse this data.\n"
171
- ]
172
- },
173
- {
174
- "data": {
175
- "application/vnd.jupyter.widget-view+json": {
176
- "model_id": "3c93229d66ad46d9a88da5f6a9528f2e",
177
  "version_major": 2,
178
  "version_minor": 0
179
  },
@@ -191,7 +111,7 @@
191
  },
192
  {
193
  "cell_type": "code",
194
- "execution_count": 12,
195
  "metadata": {},
196
  "outputs": [],
197
  "source": [
@@ -200,16 +120,18 @@
200
  },
201
  {
202
  "cell_type": "code",
203
- "execution_count": 13,
204
  "metadata": {},
205
  "outputs": [],
206
  "source": [
207
- "config = AutoConfig.from_pretrained('gpt2')"
 
 
208
  ]
209
  },
210
  {
211
  "cell_type": "code",
212
- "execution_count": 16,
213
  "metadata": {},
214
  "outputs": [
215
  {
@@ -218,7 +140,7 @@
218
  "Embedding(50257, 768)"
219
  ]
220
  },
221
- "execution_count": 16,
222
  "metadata": {},
223
  "output_type": "execute_result"
224
  }
@@ -228,12 +150,13 @@
228
  " \"gpt2\",\n",
229
  " config=config\n",
230
  ")\n",
 
231
  "model.resize_token_embeddings(len(tokenizer))"
232
  ]
233
  },
234
  {
235
  "cell_type": "code",
236
- "execution_count": 24,
237
  "metadata": {},
238
  "outputs": [
239
  {
@@ -245,7 +168,7 @@
245
  "})"
246
  ]
247
  },
248
- "execution_count": 24,
249
  "metadata": {},
250
  "output_type": "execute_result"
251
  }
@@ -256,7 +179,7 @@
256
  },
257
  {
258
  "cell_type": "code",
259
- "execution_count": 26,
260
  "metadata": {},
261
  "outputs": [
262
  {
@@ -265,7 +188,7 @@
265
  "'Mythology & Folklore'"
266
  ]
267
  },
268
- "execution_count": 26,
269
  "metadata": {},
270
  "output_type": "execute_result"
271
  }
@@ -276,7 +199,7 @@
276
  },
277
  {
278
  "cell_type": "code",
279
- "execution_count": 28,
280
  "metadata": {},
281
  "outputs": [
282
  {
@@ -290,7 +213,7 @@
290
  "})"
291
  ]
292
  },
293
- "execution_count": 28,
294
  "metadata": {},
295
  "output_type": "execute_result"
296
  }
@@ -301,7 +224,7 @@
301
  },
302
  {
303
  "cell_type": "code",
304
- "execution_count": 29,
305
  "metadata": {},
306
  "outputs": [],
307
  "source": [
@@ -312,7 +235,7 @@
312
  },
313
  {
314
  "cell_type": "code",
315
- "execution_count": 30,
316
  "metadata": {},
317
  "outputs": [],
318
  "source": [
@@ -330,7 +253,7 @@
330
  },
331
  {
332
  "cell_type": "code",
333
- "execution_count": 33,
334
  "metadata": {},
335
  "outputs": [],
336
  "source": [
@@ -341,29 +264,14 @@
341
  },
342
  {
343
  "cell_type": "code",
344
- "execution_count": 34,
345
  "metadata": {},
346
  "outputs": [
347
- {
348
- "data": {
349
- "application/vnd.jupyter.widget-view+json": {
350
- "model_id": "82c09dbdfa1a47d79607a4c9729fb286",
351
- "version_major": 2,
352
- "version_minor": 0
353
- },
354
- "text/plain": [
355
- "Running tokenizer on dataset: 0%| | 0/1 [00:00<?, ?ba/s]"
356
- ]
357
- },
358
- "metadata": {},
359
- "output_type": "display_data"
360
- },
361
  {
362
  "name": "stderr",
363
  "output_type": "stream",
364
  "text": [
365
- "Token indices sequence length is longer than the specified maximum sequence length for this model (7725 > 1024). Running this sequence through the model will result in indexing errors\n",
366
- "^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits before being passed to the model.\n"
367
  ]
368
  }
369
  ],
@@ -380,7 +288,7 @@
380
  },
381
  {
382
  "cell_type": "code",
383
- "execution_count": 39,
384
  "metadata": {},
385
  "outputs": [],
386
  "source": [
@@ -389,7 +297,7 @@
389
  },
390
  {
391
  "cell_type": "code",
392
- "execution_count": 41,
393
  "metadata": {},
394
  "outputs": [],
395
  "source": [
@@ -413,22 +321,15 @@
413
  },
414
  {
415
  "cell_type": "code",
416
- "execution_count": 44,
417
  "metadata": {},
418
  "outputs": [
419
  {
420
- "data": {
421
- "application/vnd.jupyter.widget-view+json": {
422
- "model_id": "ca2f64461e304df6aecb16e8cfcd42ac",
423
- "version_major": 2,
424
- "version_minor": 0
425
- },
426
- "text/plain": [
427
- "Grouping texts in chunks of 1024: 0%| | 0/1 [00:00<?, ?ba/s]"
428
- ]
429
- },
430
- "metadata": {},
431
- "output_type": "display_data"
432
  }
433
  ],
434
  "source": [
@@ -443,7 +344,7 @@
443
  },
444
  {
445
  "cell_type": "code",
446
- "execution_count": 46,
447
  "metadata": {},
448
  "outputs": [],
449
  "source": [
@@ -459,14 +360,30 @@
459
  },
460
  {
461
  "cell_type": "code",
462
- "execution_count": 47,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
463
  "metadata": {},
464
  "outputs": [],
465
  "source": [
466
  "# Initialize our Trainer\n",
467
  "trainer = Trainer(\n",
468
  " model=model,\n",
469
- " # args=training_args,\n",
470
  " train_dataset=train_dataset,\n",
471
  " # eval_dataset=eval_dataset,\n",
472
  " tokenizer=tokenizer,\n",
@@ -483,7 +400,7 @@
483
  },
484
  {
485
  "cell_type": "code",
486
- "execution_count": 48,
487
  "metadata": {},
488
  "outputs": [
489
  {
@@ -558,18 +475,69 @@
558
  ],
559
  "source": [
560
  "# Training\n",
561
- "checkpoint = None\n",
562
- "train_result = trainer.train(resume_from_checkpoint=checkpoint)\n",
563
- "trainer.save_model() # Saves the tokenizer too for easy upload\n",
564
  "\n",
565
- "metrics = train_result.metrics\n",
566
  "\n",
567
- "max_train_samples = (len(train_dataset))\n",
568
- "metrics[\"train_samples\"] = min(max_train_samples, len(train_dataset))\n",
569
  "\n",
570
- "trainer.log_metrics(\"train\", metrics)\n",
571
- "trainer.save_metrics(\"train\", metrics)\n",
572
- "trainer.save_state()"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
573
  ]
574
  }
575
  ],
 
2
  "cells": [
3
  {
4
  "cell_type": "code",
5
+ "execution_count": 3,
6
  "metadata": {},
7
  "outputs": [],
8
  "source": [
 
11
  },
12
  {
13
  "cell_type": "code",
14
+ "execution_count": 4,
15
  "metadata": {},
16
  "outputs": [],
17
  "source": [
 
43
  },
44
  {
45
  "cell_type": "code",
46
+ "execution_count": 5,
47
  "metadata": {},
48
+ "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
 
 
49
  "source": [
50
  "# check_min_version(\"4.23.0.dev0\")"
51
  ]
52
  },
53
  {
54
  "cell_type": "code",
55
+ "execution_count": 6,
56
  "metadata": {},
57
  "outputs": [],
58
  "source": [
 
61
  },
62
  {
63
  "cell_type": "code",
64
+ "execution_count": 7,
65
  "metadata": {},
66
  "outputs": [],
67
  "source": [
 
77
  },
78
  {
79
  "cell_type": "code",
80
+ "execution_count": 8,
81
  "metadata": {},
82
  "outputs": [
83
  {
84
  "name": "stderr",
85
  "output_type": "stream",
86
  "text": [
87
+ "/opt/homebrew/Caskroom/miniforge/base/envs/augmented_poetry/lib/python3.8/site-packages/huggingface_hub/utils/_deprecation.py:97: FutureWarning: Deprecated argument(s) used in 'dataset_info': token. Will not be supported from version '0.12'.\n",
88
+ " warnings.warn(message, FutureWarning)\n",
89
+ "Using custom data configuration merve--poetry-ca9a13ef5858cc3a\n",
90
+ "Found cached dataset csv (/Users/matth/.cache/huggingface/datasets/merve___csv/merve--poetry-ca9a13ef5858cc3a/0.0.0/652c3096f041ee27b04d2232d41f10547a8fecda3e284a79a0ec4053c916ef7a)\n"
 
 
 
 
91
  ]
92
  },
93
  {
94
  "data": {
95
  "application/vnd.jupyter.widget-view+json": {
96
+ "model_id": "67606d054e4a4b2f9ddf99f07c02c328",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
  "version_major": 2,
98
  "version_minor": 0
99
  },
 
111
  },
112
  {
113
  "cell_type": "code",
114
+ "execution_count": 9,
115
  "metadata": {},
116
  "outputs": [],
117
  "source": [
 
120
  },
121
  {
122
  "cell_type": "code",
123
+ "execution_count": 10,
124
  "metadata": {},
125
  "outputs": [],
126
  "source": [
127
+ "config = AutoConfig.from_pretrained('gpt2')\n",
128
+ "\n",
129
+ "# max_seq_length"
130
  ]
131
  },
132
  {
133
  "cell_type": "code",
134
+ "execution_count": 11,
135
  "metadata": {},
136
  "outputs": [
137
  {
 
140
  "Embedding(50257, 768)"
141
  ]
142
  },
143
+ "execution_count": 11,
144
  "metadata": {},
145
  "output_type": "execute_result"
146
  }
 
150
  " \"gpt2\",\n",
151
  " config=config\n",
152
  ")\n",
153
+ "model.max_seq_length = 128\n",
154
  "model.resize_token_embeddings(len(tokenizer))"
155
  ]
156
  },
157
  {
158
  "cell_type": "code",
159
+ "execution_count": 12,
160
  "metadata": {},
161
  "outputs": [
162
  {
 
168
  "})"
169
  ]
170
  },
171
+ "execution_count": 12,
172
  "metadata": {},
173
  "output_type": "execute_result"
174
  }
 
179
  },
180
  {
181
  "cell_type": "code",
182
+ "execution_count": 13,
183
  "metadata": {},
184
  "outputs": [
185
  {
 
188
  "'Mythology & Folklore'"
189
  ]
190
  },
191
+ "execution_count": 13,
192
  "metadata": {},
193
  "output_type": "execute_result"
194
  }
 
199
  },
200
  {
201
  "cell_type": "code",
202
+ "execution_count": 14,
203
  "metadata": {},
204
  "outputs": [
205
  {
 
213
  "})"
214
  ]
215
  },
216
+ "execution_count": 14,
217
  "metadata": {},
218
  "output_type": "execute_result"
219
  }
 
224
  },
225
  {
226
  "cell_type": "code",
227
+ "execution_count": 15,
228
  "metadata": {},
229
  "outputs": [],
230
  "source": [
 
235
  },
236
  {
237
  "cell_type": "code",
238
+ "execution_count": 16,
239
  "metadata": {},
240
  "outputs": [],
241
  "source": [
 
253
  },
254
  {
255
  "cell_type": "code",
256
+ "execution_count": 17,
257
  "metadata": {},
258
  "outputs": [],
259
  "source": [
 
264
  },
265
  {
266
  "cell_type": "code",
267
+ "execution_count": 18,
268
  "metadata": {},
269
  "outputs": [
 
 
 
 
 
 
 
 
 
 
 
 
 
 
270
  {
271
  "name": "stderr",
272
  "output_type": "stream",
273
  "text": [
274
+ "Loading cached processed dataset at /Users/matth/.cache/huggingface/datasets/merve___csv/merve--poetry-ca9a13ef5858cc3a/0.0.0/652c3096f041ee27b04d2232d41f10547a8fecda3e284a79a0ec4053c916ef7a/cache-62fd9c772e30c8d3.arrow\n"
 
275
  ]
276
  }
277
  ],
 
288
  },
289
  {
290
  "cell_type": "code",
291
+ "execution_count": 19,
292
  "metadata": {},
293
  "outputs": [],
294
  "source": [
 
297
  },
298
  {
299
  "cell_type": "code",
300
+ "execution_count": 20,
301
  "metadata": {},
302
  "outputs": [],
303
  "source": [
 
321
  },
322
  {
323
  "cell_type": "code",
324
+ "execution_count": 21,
325
  "metadata": {},
326
  "outputs": [
327
  {
328
+ "name": "stderr",
329
+ "output_type": "stream",
330
+ "text": [
331
+ "Loading cached processed dataset at /Users/matth/.cache/huggingface/datasets/merve___csv/merve--poetry-ca9a13ef5858cc3a/0.0.0/652c3096f041ee27b04d2232d41f10547a8fecda3e284a79a0ec4053c916ef7a/cache-88d7c64be469684a.arrow\n"
332
+ ]
 
 
 
 
 
 
 
333
  }
334
  ],
335
  "source": [
 
344
  },
345
  {
346
  "cell_type": "code",
347
+ "execution_count": 22,
348
  "metadata": {},
349
  "outputs": [],
350
  "source": [
 
360
  },
361
  {
362
  "cell_type": "code",
363
+ "execution_count": 25,
364
+ "metadata": {},
365
+ "outputs": [],
366
+ "source": [
367
+ "training_args = TrainingArguments(\n",
368
+ " output_dir=\"gpt2-poetry-model\", \n",
369
+ " overwrite_output_dir=True,\n",
370
+ " # per_gpu_train_batch_size=256\n",
371
+ " per_device_train_batch_size=16,\n",
372
+ " push_to_hub=True,\n",
373
+ " push_to_hub_token=\"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\"\n",
374
+ ")"
375
+ ]
376
+ },
377
+ {
378
+ "cell_type": "code",
379
+ "execution_count": 26,
380
  "metadata": {},
381
  "outputs": [],
382
  "source": [
383
  "# Initialize our Trainer\n",
384
  "trainer = Trainer(\n",
385
  " model=model,\n",
386
+ " args=training_args,\n",
387
  " train_dataset=train_dataset,\n",
388
  " # eval_dataset=eval_dataset,\n",
389
  " tokenizer=tokenizer,\n",
 
400
  },
401
  {
402
  "cell_type": "code",
403
+ "execution_count": null,
404
  "metadata": {},
405
  "outputs": [
406
  {
 
475
  ],
476
  "source": [
477
  "# Training\n",
478
+ "# checkpoint = None\n",
479
+ "# train_result = trainer.train(resume_from_checkpoint=checkpoint)\n",
480
+ "# trainer.save_model() # Saves the tokenizer too for easy upload\n",
481
  "\n",
482
+ "# metrics = train_result.metrics\n",
483
  "\n",
484
+ "# max_train_samples = (len(train_dataset))\n",
485
+ "# metrics[\"train_samples\"] = min(max_train_samples, len(train_dataset))\n",
486
  "\n",
487
+ "# trainer.log_metrics(\"train\", metrics)\n",
488
+ "# trainer.save_metrics(\"train\", metrics)\n",
489
+ "# trainer.save_state()\n",
490
+ "# # Upload the the hugging face hub for easy use in inference.\n",
491
+ "# trainer.push_to_hub()"
492
+ ]
493
+ },
494
+ {
495
+ "cell_type": "code",
496
+ "execution_count": 27,
497
+ "metadata": {},
498
+ "outputs": [
499
+ {
500
+ "data": {
501
+ "application/vnd.jupyter.widget-view+json": {
502
+ "model_id": "2cec8af2b332409bb857695a7b099653",
503
+ "version_major": 2,
504
+ "version_minor": 0
505
+ },
506
+ "text/plain": [
507
+ "VBox(children=(HTML(value='<center> <img\\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…"
508
+ ]
509
+ },
510
+ "metadata": {},
511
+ "output_type": "display_data"
512
+ },
513
+ {
514
+ "name": "stderr",
515
+ "output_type": "stream",
516
+ "text": [
517
+ "Saving model checkpoint to gpt2-poetry-model\n",
518
+ "Configuration saved in gpt2-poetry-model/config.json\n",
519
+ "Model weights saved in gpt2-poetry-model/pytorch_model.bin\n",
520
+ "tokenizer config file saved in gpt2-poetry-model/tokenizer_config.json\n",
521
+ "Special tokens file saved in gpt2-poetry-model/special_tokens_map.json\n"
522
+ ]
523
+ },
524
+ {
525
+ "ename": "AttributeError",
526
+ "evalue": "'Trainer' object has no attribute 'repo'",
527
+ "output_type": "error",
528
+ "traceback": [
529
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
530
+ "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)",
531
+ "Cell \u001b[0;32mIn [27], line 3\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mhuggingface_hub\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m notebook_login\n\u001b[1;32m 2\u001b[0m notebook_login()\n\u001b[0;32m----> 3\u001b[0m \u001b[43mtrainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mpush_to_hub\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n",
532
+ "File \u001b[0;32m/opt/homebrew/Caskroom/miniforge/base/envs/augmented_poetry/lib/python3.8/site-packages/transformers/trainer.py:2677\u001b[0m, in \u001b[0;36mTrainer.push_to_hub\u001b[0;34m(self, commit_message, blocking, **kwargs)\u001b[0m\n\u001b[1;32m 2674\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mis_world_process_zero():\n\u001b[1;32m 2675\u001b[0m \u001b[39mreturn\u001b[39;00m\n\u001b[0;32m-> 2677\u001b[0m git_head_commit_url \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mrepo\u001b[39m.\u001b[39mpush_to_hub(commit_message\u001b[39m=\u001b[39mcommit_message, blocking\u001b[39m=\u001b[39mblocking)\n\u001b[1;32m 2678\u001b[0m \u001b[39m# push separately the model card to be independant from the rest of the model\u001b[39;00m\n\u001b[1;32m 2679\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39margs\u001b[39m.\u001b[39mshould_save:\n",
533
+ "\u001b[0;31mAttributeError\u001b[0m: 'Trainer' object has no attribute 'repo'"
534
+ ]
535
+ }
536
+ ],
537
+ "source": [
538
+ "from huggingface_hub import notebook_login\n",
539
+ "notebook_login()\n",
540
+ "trainer.push_to_hub()"
541
  ]
542
  }
543
  ],