Question Answering
sanjudebnath committed on
Commit
4743e80
verified
1 Parent(s): 9515c2a

Upload 8 files

Files changed (8)
  1. application.py +70 -0
  2. distilbert.ipynb +981 -0
  3. distilbert.py +175 -0
  4. load_data.ipynb +1209 -0
  5. qa_model.py +532 -0
  6. question_answering.ipynb +2403 -0
  7. requirements.txt +168 -0
  8. util.py +134 -0
application.py ADDED
@@ -0,0 +1,70 @@
+ import streamlit as st
+ import numpy as np
+ import torch
+ from transformers import DistilBertTokenizer, DistilBertForMaskedLM
+
+ from qa_model import ReuseQuestionDistilBERT
+
+ @st.cache(allow_output_mutation=True)
+ def load_model():
+     mod = DistilBertForMaskedLM.from_pretrained("distilbert-base-uncased").distilbert
+     m = ReuseQuestionDistilBERT(mod)
+     m.load_state_dict(torch.load("distilbert_reuse.model", map_location=torch.device('cpu')))
+     model = m
+     del mod
+     del m
+     tokenizer = DistilBertTokenizer.from_pretrained('qa_tokenizer')
+     return model, tokenizer
+
+
+ def get_answer(question, text, tokenizer, model):
+     question = [question.strip()]
+     text = [text.strip()]
+
+     inputs = tokenizer(
+         question,
+         text,
+         max_length=512,
+         truncation="only_second",
+         padding="max_length",
+     )
+     input_ids = torch.tensor(inputs['input_ids'])
+     outputs = model(input_ids, attention_mask=torch.tensor(inputs['attention_mask']), start_positions=None, end_positions=None)
+
+     start = torch.argmax(outputs['start_logits'])
+     end = torch.argmax(outputs['end_logits'])
+
+     ans_tokens = input_ids[0][start: end + 1]
+
+     answer_tokens = tokenizer.convert_ids_to_tokens(ans_tokens, skip_special_tokens=True)
+     predicted = tokenizer.convert_tokens_to_string(answer_tokens)
+     return predicted
+
+
+ def main():
+     st.set_page_config(page_title="Question Answering Tool", page_icon=":mag_right:")
+
+     st.write("# Question Answering Tool \n"
+              "This tool will help you find answers to your questions about the text you provide. \n"
+              "Please enter your question and the text you want to search in the boxes below.")
+     model, tokenizer = load_model()
+
+     with st.form("qa_form"):
+         # define a streamlit textarea
+         text = st.text_area("Enter your text here", on_change=None)
+
+         # define a streamlit input
+         question = st.text_input("Enter your question here")
+
+         if st.form_submit_button("Submit"):
+             data_load_state = st.text('Let me think about that...')
+             # call the function to get the answer
+             answer = get_answer(question, text, tokenizer, model)
+             # display the answer
+             if answer == "":
+                 data_load_state.text("Sorry but I don't know the answer to that question")
+             else:
+                 data_load_state.text(answer)
+
+
+ main()
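
For reference, the app above is launched with `streamlit run application.py`. The same inference path can also be exercised from plain Python; the following is a minimal sketch that mirrors `load_model`/`get_answer` from `application.py` and assumes the `distilbert_reuse.model` weights and the `qa_tokenizer/` directory from this repository are present in the working directory.

```python
# Minimal sketch (assumes distilbert_reuse.model and qa_tokenizer/ exist locally).
import torch
from transformers import DistilBertTokenizer, DistilBertForMaskedLM
from qa_model import ReuseQuestionDistilBERT

# rebuild the fine-tuned QA model exactly as load_model() in application.py does
base = DistilBertForMaskedLM.from_pretrained("distilbert-base-uncased").distilbert
model = ReuseQuestionDistilBERT(base)
model.load_state_dict(torch.load("distilbert_reuse.model", map_location="cpu"))
model.eval()
tokenizer = DistilBertTokenizer.from_pretrained("qa_tokenizer")

question = "When was the Eiffel Tower completed?"
context = "The Eiffel Tower was completed in 1889 and is located in Paris."

inputs = tokenizer([question], [context], max_length=512,
                   truncation="only_second", padding="max_length")
input_ids = torch.tensor(inputs["input_ids"])
with torch.no_grad():
    out = model(input_ids,
                attention_mask=torch.tensor(inputs["attention_mask"]),
                start_positions=None, end_positions=None)

# take the most likely start/end positions and decode the answer span
start = torch.argmax(out["start_logits"])
end = torch.argmax(out["end_logits"])
tokens = tokenizer.convert_ids_to_tokens(input_ids[0][start:end + 1], skip_special_tokens=True)
print(tokenizer.convert_tokens_to_string(tokens))
```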
distilbert.ipynb ADDED
@@ -0,0 +1,981 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "47700837",
6
+ "metadata": {},
7
+ "source": [
8
+ "# DistilBERT Base Model\n",
9
+ "The following contains the code to create and train a DistilBERT model using the Huggingface library. It works quite well for a moderate amount of data, but the runtime increases quite drastically with data.\n",
10
+ "\n",
11
+ "I decided to take the pretrained model after all, still, creating the model myself was quite interesting!"
12
+ ]
13
+ },
14
+ {
15
+ "cell_type": "code",
16
+ "execution_count": 4,
17
+ "id": "c09fa906",
18
+ "metadata": {},
19
+ "outputs": [],
20
+ "source": [
21
+ "from pathlib import Path\n",
22
+ "import torch\n",
23
+ "import time\n",
24
+ "from pathlib import Path\n",
25
+ "from transformers import DistilBertTokenizerFast\n",
26
+ "import os\n",
27
+ "from transformers import DistilBertConfig\n",
28
+ "from transformers import DistilBertForMaskedLM\n",
29
+ "from tokenizers import BertWordPieceTokenizer\n",
30
+ "from tqdm.auto import tqdm\n",
31
+ "from torch.optim import AdamW\n",
32
+ "import torchtest\n",
33
+ "from transformers import pipeline\n",
34
+ "\n",
35
+ "\n",
36
+ "from distilbert import test_model\n",
37
+ "from distilbert import Dataset\n",
38
+ "\n",
39
+ "import numpy as np"
40
+ ]
41
+ },
42
+ {
43
+ "cell_type": "markdown",
44
+ "id": "3b773fac",
45
+ "metadata": {},
46
+ "source": [
47
+ "## Tokeniser\n",
48
+ "We need a way to convert the strings we get as the input to numerical tokens, that we can give to the neual network. Hence, we take a BertWorkPieceTokenizer (works for DistilBERT too) and create tokens from our words."
49
+ ]
50
+ },
51
+ {
52
+ "cell_type": "code",
53
+ "execution_count": 5,
54
+ "id": "24277c5b",
55
+ "metadata": {},
56
+ "outputs": [
57
+ {
58
+ "name": "stdout",
59
+ "output_type": "stream",
60
+ "text": [
61
+ "Tokeniser created\n"
62
+ ]
63
+ }
64
+ ],
65
+ "source": [
66
+ "fit_new_tokenizer = True\n",
67
+ "\n",
68
+ "if fit_new_tokenizer:\n",
69
+ " paths = [str(x) for x in Path('data/original').glob('**/*.txt')]\n",
70
+ "\n",
71
+ " tokenizer = BertWordPieceTokenizer(\n",
72
+ " clean_text=True,\n",
73
+ " handle_chinese_chars=False,\n",
74
+ " strip_accents=False,\n",
75
+ " lowercase=True\n",
76
+ " )\n",
77
+ " print(\"Tokeniser created\")"
78
+ ]
79
+ },
80
+ {
81
+ "cell_type": "code",
82
+ "execution_count": 6,
83
+ "id": "beacf3e3",
84
+ "metadata": {},
85
+ "outputs": [
86
+ {
87
+ "name": "stdout",
88
+ "output_type": "stream",
89
+ "text": [
90
+ "\n",
91
+ "\n",
92
+ "\n"
93
+ ]
94
+ }
95
+ ],
96
+ "source": [
97
+ "# fit the tokenizer\n",
98
+ "if fit_new_tokenizer:\n",
99
+ " tokenizer.train(files=paths[:10], vocab_size=30_000, min_frequency=2,\n",
100
+ " limit_alphabet=1000, wordpieces_prefix='##',\n",
101
+ " special_tokens=['[PAD]', '[UNK]', '[CLS]', '[SEP]', '[MASK]'])"
102
+ ]
103
+ },
104
+ {
105
+ "cell_type": "code",
106
+ "execution_count": 7,
107
+ "id": "0d462cc5",
108
+ "metadata": {},
109
+ "outputs": [
110
+ {
111
+ "ename": "FileExistsError",
112
+ "evalue": "[Errno 17] File exists: './tokeniser'",
113
+ "output_type": "error",
114
+ "traceback": [
115
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
116
+ "\u001b[0;31mFileExistsError\u001b[0m Traceback (most recent call last)",
117
+ "Cell \u001b[0;32mIn [7], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m fit_new_tokenizer:\n\u001b[0;32m----> 2\u001b[0m os\u001b[38;5;241m.\u001b[39mmkdir(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m./tokeniser\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[1;32m 3\u001b[0m tokenizer\u001b[38;5;241m.\u001b[39msave_model(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mtokeniser\u001b[39m\u001b[38;5;124m'\u001b[39m)\n\u001b[1;32m 4\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mTokeniser saved\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n",
118
+ "\u001b[0;31mFileExistsError\u001b[0m: [Errno 17] File exists: './tokeniser'"
119
+ ]
120
+ }
121
+ ],
122
+ "source": [
123
+ "if fit_new_tokenizer:\n",
124
+ " os.mkdir('./tokeniser')\n",
125
+ " tokenizer.save_model('tokeniser')\n",
126
+ " print(\"Tokeniser saved\")"
127
+ ]
128
+ },
129
+ {
130
+ "cell_type": "markdown",
131
+ "id": "7eaa1667",
132
+ "metadata": {},
133
+ "source": [
134
+ "After having created a basic tokeniser, we use the model to initialise a DistilBert tokenizer, that we need for the model architecture later on. We save the tokeniser separately."
135
+ ]
136
+ },
137
+ {
138
+ "cell_type": "code",
139
+ "execution_count": 8,
140
+ "id": "f4dd0684",
141
+ "metadata": {},
142
+ "outputs": [
143
+ {
144
+ "data": {
145
+ "text/plain": [
146
+ "('distilbert_tokenizer/tokenizer_config.json',\n",
147
+ " 'distilbert_tokenizer/special_tokens_map.json',\n",
148
+ " 'distilbert_tokenizer/vocab.txt',\n",
149
+ " 'distilbert_tokenizer/added_tokens.json',\n",
150
+ " 'distilbert_tokenizer/tokenizer.json')"
151
+ ]
152
+ },
153
+ "execution_count": 8,
154
+ "metadata": {},
155
+ "output_type": "execute_result"
156
+ }
157
+ ],
158
+ "source": [
159
+ "tokenizer = DistilBertTokenizerFast.from_pretrained('tokeniser', max_len=512)\n",
160
+ "tokenizer.save_pretrained(\"distilbert_tokenizer\")"
161
+ ]
162
+ },
163
+ {
164
+ "cell_type": "markdown",
165
+ "id": "bfcafcde",
166
+ "metadata": {},
167
+ "source": [
168
+ "### Testing\n",
169
+ "We now test the created tokenizer. We take a simple example and tokenise the input. It can be seen that we add a special token in the beginning and end ('CLS' and 'SEP'), which is how the BERT model was defined.\n",
170
+ "\n",
171
+ "When we translate the input back, we can see that we get the same, except for the first and last token. Also, we can see that questionmarks and commas are encoded separately."
172
+ ]
173
+ },
174
+ {
175
+ "cell_type": "code",
176
+ "execution_count": 9,
177
+ "id": "37e7f6a8",
178
+ "metadata": {},
179
+ "outputs": [
180
+ {
181
+ "name": "stdout",
182
+ "output_type": "stream",
183
+ "text": [
184
+ "{'input_ids': [2, 10958, 16, 2175, 1993, 1965, 35, 3], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1]}\n"
185
+ ]
186
+ }
187
+ ],
188
+ "source": [
189
+ "tokens = tokenizer('Hello, how are you?')\n",
190
+ "print(tokens)"
191
+ ]
192
+ },
193
+ {
194
+ "cell_type": "code",
195
+ "execution_count": 10,
196
+ "id": "bbd0c4b1",
197
+ "metadata": {},
198
+ "outputs": [
199
+ {
200
+ "data": {
201
+ "text/plain": [
202
+ "'[CLS] hello, how are you? [SEP]'"
203
+ ]
204
+ },
205
+ "execution_count": 10,
206
+ "metadata": {},
207
+ "output_type": "execute_result"
208
+ }
209
+ ],
210
+ "source": [
211
+ "tokenizer.decode(tokens['input_ids'])"
212
+ ]
213
+ },
214
+ {
215
+ "cell_type": "code",
216
+ "execution_count": 11,
217
+ "id": "4ab6e506",
218
+ "metadata": {},
219
+ "outputs": [
220
+ {
221
+ "name": "stdout",
222
+ "output_type": "stream",
223
+ "text": [
224
+ "[CLS]\n",
225
+ "hello\n",
226
+ ",\n",
227
+ "how\n",
228
+ "are\n",
229
+ "you\n",
230
+ "?\n",
231
+ "[SEP]\n"
232
+ ]
233
+ }
234
+ ],
235
+ "source": [
236
+ "for tok in tokens['input_ids']:\n",
237
+ " print(tokenizer.decode(tok))"
238
+ ]
239
+ },
240
+ {
241
+ "cell_type": "code",
242
+ "execution_count": 12,
243
+ "id": "c75d3255",
244
+ "metadata": {},
245
+ "outputs": [],
246
+ "source": [
247
+ "assert len(tokenizer.vocab) == 30_000"
248
+ ]
249
+ },
250
+ {
251
+ "cell_type": "markdown",
252
+ "id": "dd114355",
253
+ "metadata": {},
254
+ "source": [
255
+ "## Dataset\n",
256
+ "We now define a function to mask some of the tokens. In particular, we create a Dataset class, that automates loading the data and tokenising it for us. Lastly, we use a DataLoader to load the data step by step into memory.\n",
257
+ "\n",
258
+ "The big problem with the limited resources we have is memory. In particular, I am loading the data sequentially, file by file, keeping track how many samples have been read. Shuffling wouldn't work here (it would also not make a lot of sense for this dataset)."
259
+ ]
260
+ },
261
+ {
262
+ "cell_type": "code",
263
+ "execution_count": 10,
264
+ "id": "bff9ea54",
265
+ "metadata": {},
266
+ "outputs": [],
267
+ "source": [
268
+ "# create dataset and dataloader \n",
269
+ "dataset = Dataset(paths = [str(x) for x in Path('data/original').glob('**/*.txt')][50:70], tokenizer=tokenizer)\n",
270
+ "loader = torch.utils.data.DataLoader(dataset, batch_size=8)\n",
271
+ "\n",
272
+ "test_dataset = Dataset(paths = [str(x) for x in Path('data/original').glob('**/*.txt')][10:12], tokenizer=tokenizer)\n",
273
+ "test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=4)"
274
+ ]
275
+ },
276
+ {
277
+ "cell_type": "markdown",
278
+ "id": "6bbe6e63",
279
+ "metadata": {},
280
+ "source": [
281
+ "### Testing\n",
282
+ "The randomisation makes it a bit difficult to test. But altogether, we see that the input ids, masks and labels have the same shape. Also, as we mask 15% of the samples, when decoding a given sample, we can see that some samples are now '[MASK]'."
283
+ ]
284
+ },
285
+ {
286
+ "cell_type": "code",
287
+ "execution_count": 11,
288
+ "id": "436ab745",
289
+ "metadata": {},
290
+ "outputs": [],
291
+ "source": [
292
+ "i = iter(dataset)"
293
+ ]
294
+ },
295
+ {
296
+ "cell_type": "code",
297
+ "execution_count": 12,
298
+ "id": "330e599d",
299
+ "metadata": {},
300
+ "outputs": [
301
+ {
302
+ "name": "stdout",
303
+ "output_type": "stream",
304
+ "text": [
305
+ "Passed\n"
306
+ ]
307
+ }
308
+ ],
309
+ "source": [
310
+ "for j in range(10):\n",
311
+ " sample = next(i)\n",
312
+ " \n",
313
+ " input_ids = sample['input_ids']\n",
314
+ " attention_masks = sample['attention_mask']\n",
315
+ " labels = sample['labels']\n",
316
+ " \n",
317
+ " # check if the dimensions are right\n",
318
+ " assert input_ids.shape[0] == (512)\n",
319
+ " assert attention_masks.shape[0] == (512)\n",
320
+ " assert labels.shape[0] == (512)\n",
321
+ " \n",
322
+ " # if the input ids are not masked, the labels are the same as the input ids\n",
323
+ " assert np.array_equal(input_ids[input_ids != 4].numpy(),labels[input_ids != 4].numpy())\n",
324
+ " # input ids are zero if the attention masks are zero\n",
325
+ " assert np.all(input_ids[attention_masks == 0].numpy()==0)\n",
326
+ " # check if input contains masked tokens (we can't guarantee this 100% but this will apply) most likely\n",
327
+ " assert np.any(input_ids.numpy() == 4)\n",
328
+ "print(\"Passed\")"
329
+ ]
330
+ },
331
+ {
332
+ "cell_type": "markdown",
333
+ "id": "08db6d22",
334
+ "metadata": {},
335
+ "source": [
336
+ "## Model\n",
337
+ "In the following section, we intialise and train a model."
338
+ ]
339
+ },
340
+ {
341
+ "cell_type": "code",
342
+ "execution_count": 13,
343
+ "id": "7803bda6",
344
+ "metadata": {},
345
+ "outputs": [],
346
+ "source": [
347
+ "config = DistilBertConfig(\n",
348
+ " vocab_size=30000,\n",
349
+ " max_position_embeddings=514\n",
350
+ ")"
351
+ ]
352
+ },
353
+ {
354
+ "cell_type": "code",
355
+ "execution_count": 14,
356
+ "id": "8ca03f6a",
357
+ "metadata": {},
358
+ "outputs": [],
359
+ "source": [
360
+ "model = DistilBertForMaskedLM(config)"
361
+ ]
362
+ },
363
+ {
364
+ "cell_type": "code",
365
+ "execution_count": 15,
366
+ "id": "4da22bff",
367
+ "metadata": {
368
+ "scrolled": false
369
+ },
370
+ "outputs": [
371
+ {
372
+ "name": "stderr",
373
+ "output_type": "stream",
374
+ "text": [
375
+ "/home/sanju/anaconda3/envs/myenv/lib/python3.10/site-packages/torch/cuda/__init__.py:83: UserWarning: CUDA initialization: CUDA unknown error - this may be due to an incorrectly set up environment, e.g. changing env variable CUDA_VISIBLE_DEVICES after program start. Setting the available devices to be zero. (Triggered internally at ../c10/cuda/CUDAFunctions.cpp:109.)\n",
376
+ " return torch._C._cuda_getDeviceCount() > 0\n"
377
+ ]
378
+ },
379
+ {
380
+ "data": {
381
+ "text/plain": [
382
+ "DistilBertForMaskedLM(\n",
383
+ " (activation): GELUActivation()\n",
384
+ " (distilbert): DistilBertModel(\n",
385
+ " (embeddings): Embeddings(\n",
386
+ " (word_embeddings): Embedding(30000, 768, padding_idx=0)\n",
387
+ " (position_embeddings): Embedding(514, 768)\n",
388
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
389
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
390
+ " )\n",
391
+ " (transformer): Transformer(\n",
392
+ " (layer): ModuleList(\n",
393
+ " (0): TransformerBlock(\n",
394
+ " (attention): MultiHeadSelfAttention(\n",
395
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
396
+ " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
397
+ " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
398
+ " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
399
+ " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
400
+ " )\n",
401
+ " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
402
+ " (ffn): FFN(\n",
403
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
404
+ " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
405
+ " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
406
+ " (activation): GELUActivation()\n",
407
+ " )\n",
408
+ " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
409
+ " )\n",
410
+ " (1): TransformerBlock(\n",
411
+ " (attention): MultiHeadSelfAttention(\n",
412
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
413
+ " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
414
+ " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
415
+ " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
416
+ " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
417
+ " )\n",
418
+ " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
419
+ " (ffn): FFN(\n",
420
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
421
+ " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
422
+ " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
423
+ " (activation): GELUActivation()\n",
424
+ " )\n",
425
+ " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
426
+ " )\n",
427
+ " (2): TransformerBlock(\n",
428
+ " (attention): MultiHeadSelfAttention(\n",
429
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
430
+ " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
431
+ " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
432
+ " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
433
+ " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
434
+ " )\n",
435
+ " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
436
+ " (ffn): FFN(\n",
437
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
438
+ " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
439
+ " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
440
+ " (activation): GELUActivation()\n",
441
+ " )\n",
442
+ " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
443
+ " )\n",
444
+ " (3): TransformerBlock(\n",
445
+ " (attention): MultiHeadSelfAttention(\n",
446
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
447
+ " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
448
+ " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
449
+ " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
450
+ " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
451
+ " )\n",
452
+ " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
453
+ " (ffn): FFN(\n",
454
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
455
+ " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
456
+ " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
457
+ " (activation): GELUActivation()\n",
458
+ " )\n",
459
+ " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
460
+ " )\n",
461
+ " (4): TransformerBlock(\n",
462
+ " (attention): MultiHeadSelfAttention(\n",
463
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
464
+ " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
465
+ " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
466
+ " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
467
+ " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
468
+ " )\n",
469
+ " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
470
+ " (ffn): FFN(\n",
471
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
472
+ " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
473
+ " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
474
+ " (activation): GELUActivation()\n",
475
+ " )\n",
476
+ " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
477
+ " )\n",
478
+ " (5): TransformerBlock(\n",
479
+ " (attention): MultiHeadSelfAttention(\n",
480
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
481
+ " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
482
+ " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
483
+ " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
484
+ " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
485
+ " )\n",
486
+ " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
487
+ " (ffn): FFN(\n",
488
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
489
+ " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
490
+ " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
491
+ " (activation): GELUActivation()\n",
492
+ " )\n",
493
+ " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
494
+ " )\n",
495
+ " )\n",
496
+ " )\n",
497
+ " )\n",
498
+ " (vocab_transform): Linear(in_features=768, out_features=768, bias=True)\n",
499
+ " (vocab_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
500
+ " (vocab_projector): Linear(in_features=768, out_features=30000, bias=True)\n",
501
+ " (mlm_loss_fct): CrossEntropyLoss()\n",
502
+ ")"
503
+ ]
504
+ },
505
+ "execution_count": 15,
506
+ "metadata": {},
507
+ "output_type": "execute_result"
508
+ }
509
+ ],
510
+ "source": [
511
+ "# if we have a GPU - train on gpu\n",
512
+ "device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')\n",
513
+ "model.to(device)"
514
+ ]
515
+ },
516
+ {
517
+ "cell_type": "markdown",
518
+ "id": "6fb8c2e2",
519
+ "metadata": {},
520
+ "source": [
521
+ "### Testing the model\n",
522
+ "I stumbled across some Medium articles on how to test DeepLearning models beforehand \n",
523
+ "* https://thenerdstation.medium.com/how-to-unit-test-machine-learning-code-57cf6fd81765: the package is however deprecated\n",
524
+ "* https://towardsdatascience.com/testing-your-pytorch-models-with-torcheck-cb689ecbc08c: released a package (torcheck)\n",
525
+ "* https://github.com/suriyadeepan/torchtest: I found this package, which is the PyTorch version of the first one and is still maintained.\n",
526
+ "\n",
527
+ "Essentially, testing a model is inherently difficult, because we do not know the result in the beginning. Still, the following four conditions should be satisfied in every model (see second reference above):\n",
528
+ "1. The parameters should change during training (if they are not frozen).\n",
529
+ "2. The parameters should not change if they are frozen.\n",
530
+ "3. The range of the ouput should be in a predefined range.\n",
531
+ "4. The parameters should never contain NaN. The same goes for the outputs too.\n",
532
+ "\n",
533
+ "I tried using the packages, but they do not trivially apply for models with multiple inputs (we have input ids and attention masks). The following is partly adapted from the torchtest package (https://github.com/suriyadeepan/torchtest/blob/master/torchtest/torchtest.py)."
534
+ ]
535
+ },
536
+ {
537
+ "cell_type": "code",
538
+ "execution_count": 16,
539
+ "id": "cfd33fa1",
540
+ "metadata": {},
541
+ "outputs": [],
542
+ "source": [
543
+ "# get smaller dataset\n",
544
+ "test_ds = Dataset(paths = [str(x) for x in Path('data/original').glob('**/*.txt')][:2], tokenizer=tokenizer)\n",
545
+ "test_ds_loader = torch.utils.data.DataLoader(test_ds, batch_size=2)\n",
546
+ "optim=torch.optim.Adam(model.parameters())"
547
+ ]
548
+ },
549
+ {
550
+ "cell_type": "code",
551
+ "execution_count": 17,
552
+ "id": "907db815",
553
+ "metadata": {},
554
+ "outputs": [
555
+ {
556
+ "name": "stdout",
557
+ "output_type": "stream",
558
+ "text": [
559
+ "Passed\n"
560
+ ]
561
+ }
562
+ ],
563
+ "source": [
564
+ "from distilbert import test_model\n",
565
+ "\n",
566
+ "test_model(model, optim, test_ds_loader, device)"
567
+ ]
568
+ },
569
+ {
570
+ "cell_type": "markdown",
571
+ "id": "c02c9c4b",
572
+ "metadata": {},
573
+ "source": [
574
+ "### Training the model\n",
575
+ "We use AdamW as the optimiser and train for 10 epochs.\n",
576
+ "\n",
577
+ "Taking the whole dataset, takes about 100 hours per epoch for me, so I wasn't able to do that."
578
+ ]
579
+ },
580
+ {
581
+ "cell_type": "code",
582
+ "execution_count": 18,
583
+ "id": "178914f8",
584
+ "metadata": {},
585
+ "outputs": [
586
+ {
587
+ "data": {
588
+ "text/plain": [
589
+ "DistilBertForMaskedLM(\n",
590
+ " (activation): GELUActivation()\n",
591
+ " (distilbert): DistilBertModel(\n",
592
+ " (embeddings): Embeddings(\n",
593
+ " (word_embeddings): Embedding(30000, 768, padding_idx=0)\n",
594
+ " (position_embeddings): Embedding(514, 768)\n",
595
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
596
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
597
+ " )\n",
598
+ " (transformer): Transformer(\n",
599
+ " (layer): ModuleList(\n",
600
+ " (0): TransformerBlock(\n",
601
+ " (attention): MultiHeadSelfAttention(\n",
602
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
603
+ " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
604
+ " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
605
+ " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
606
+ " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
607
+ " )\n",
608
+ " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
609
+ " (ffn): FFN(\n",
610
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
611
+ " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
612
+ " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
613
+ " (activation): GELUActivation()\n",
614
+ " )\n",
615
+ " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
616
+ " )\n",
617
+ " (1): TransformerBlock(\n",
618
+ " (attention): MultiHeadSelfAttention(\n",
619
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
620
+ " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
621
+ " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
622
+ " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
623
+ " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
624
+ " )\n",
625
+ " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
626
+ " (ffn): FFN(\n",
627
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
628
+ " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
629
+ " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
630
+ " (activation): GELUActivation()\n",
631
+ " )\n",
632
+ " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
633
+ " )\n",
634
+ " (2): TransformerBlock(\n",
635
+ " (attention): MultiHeadSelfAttention(\n",
636
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
637
+ " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
638
+ " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
639
+ " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
640
+ " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
641
+ " )\n",
642
+ " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
643
+ " (ffn): FFN(\n",
644
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
645
+ " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
646
+ " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
647
+ " (activation): GELUActivation()\n",
648
+ " )\n",
649
+ " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
650
+ " )\n",
651
+ " (3): TransformerBlock(\n",
652
+ " (attention): MultiHeadSelfAttention(\n",
653
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
654
+ " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
655
+ " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
656
+ " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
657
+ " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
658
+ " )\n",
659
+ " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
660
+ " (ffn): FFN(\n",
661
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
662
+ " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
663
+ " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
664
+ " (activation): GELUActivation()\n",
665
+ " )\n",
666
+ " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
667
+ " )\n",
668
+ " (4): TransformerBlock(\n",
669
+ " (attention): MultiHeadSelfAttention(\n",
670
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
671
+ " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
672
+ " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
673
+ " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
674
+ " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
675
+ " )\n",
676
+ " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
677
+ " (ffn): FFN(\n",
678
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
679
+ " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
680
+ " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
681
+ " (activation): GELUActivation()\n",
682
+ " )\n",
683
+ " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
684
+ " )\n",
685
+ " (5): TransformerBlock(\n",
686
+ " (attention): MultiHeadSelfAttention(\n",
687
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
688
+ " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
689
+ " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
690
+ " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
691
+ " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
692
+ " )\n",
693
+ " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
694
+ " (ffn): FFN(\n",
695
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
696
+ " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
697
+ " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
698
+ " (activation): GELUActivation()\n",
699
+ " )\n",
700
+ " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
701
+ " )\n",
702
+ " )\n",
703
+ " )\n",
704
+ " )\n",
705
+ " (vocab_transform): Linear(in_features=768, out_features=768, bias=True)\n",
706
+ " (vocab_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
707
+ " (vocab_projector): Linear(in_features=768, out_features=30000, bias=True)\n",
708
+ " (mlm_loss_fct): CrossEntropyLoss()\n",
709
+ ")"
710
+ ]
711
+ },
712
+ "execution_count": 18,
713
+ "metadata": {},
714
+ "output_type": "execute_result"
715
+ }
716
+ ],
717
+ "source": [
718
+ "model = DistilBertForMaskedLM(config)\n",
719
+ "# if we have a GPU - train on gpu\n",
720
+ "device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')\n",
721
+ "model.to(device)"
722
+ ]
723
+ },
724
+ {
725
+ "cell_type": "code",
726
+ "execution_count": 19,
727
+ "id": "bb6532be",
728
+ "metadata": {},
729
+ "outputs": [],
730
+ "source": [
731
+ "# we use AdamW as the optimiser\n",
732
+ "optim = AdamW(model.parameters(), lr=1e-4)"
733
+ ]
734
+ },
735
+ {
736
+ "cell_type": "code",
737
+ "execution_count": 20,
738
+ "id": "2fd5d609",
739
+ "metadata": {},
740
+ "outputs": [
741
+ {
742
+ "data": {
743
+ "application/vnd.jupyter.widget-view+json": {
744
+ "model_id": "c3386dc78c65490a96d11ade635d522f",
745
+ "version_major": 2,
746
+ "version_minor": 0
747
+ },
748
+ "text/plain": [
749
+ " 0%| | 0/23750 [00:00<?, ?it/s]"
750
+ ]
751
+ },
752
+ "metadata": {},
753
+ "output_type": "display_data"
754
+ }
755
+ ],
756
+ "source": [
757
+ "epochs = 10\n",
758
+ "\n",
759
+ "for epoch in range(epochs):\n",
760
+ " loop = tqdm(loader, leave=True)\n",
761
+ " \n",
762
+ " # set model to training mode\n",
763
+ " model.train()\n",
764
+ " losses = []\n",
765
+ " \n",
766
+ " # iterate over dataset\n",
767
+ " for batch in loop:\n",
768
+ " optim.zero_grad()\n",
769
+ " \n",
770
+ " # copy input to device\n",
771
+ " input_ids = batch['input_ids'].to(device)\n",
772
+ " attention_mask = batch['attention_mask'].to(device)\n",
773
+ " labels = batch['labels'].to(device)\n",
774
+ " \n",
775
+ " # predict\n",
776
+ " outputs = model(input_ids, attention_mask=attention_mask, labels=labels)\n",
777
+ " \n",
778
+ " # update weights\n",
779
+ " loss = outputs.loss\n",
780
+ " loss.backward()\n",
781
+ " \n",
782
+ " optim.step()\n",
783
+ " \n",
784
+ " # output current loss\n",
785
+ " loop.set_description(f'Epoch {epoch}')\n",
786
+ " loop.set_postfix(loss=loss.item())\n",
787
+ " losses.append(loss.item())\n",
788
+ " \n",
789
+ " del input_ids\n",
790
+ " del attention_mask\n",
791
+ " del labels\n",
792
+ " \n",
793
+ " print(\"Mean Training Loss\", np.mean(losses))\n",
794
+ " losses = []\n",
795
+ " loop = tqdm(test_loader, leave=True)\n",
796
+ " \n",
797
+ " # set model to evaluation mode\n",
798
+ " model.eval()\n",
799
+ " \n",
800
+ " # iterate over dataset\n",
801
+ " for batch in loop:\n",
802
+ " # copy input to device\n",
803
+ " input_ids = batch['input_ids'].to(device)\n",
804
+ " attention_mask = batch['attention_mask'].to(device)\n",
805
+ " labels = batch['labels'].to(device)\n",
806
+ " \n",
807
+ " # predict\n",
808
+ " outputs = model(input_ids, attention_mask=attention_mask, labels=labels)\n",
809
+ " \n",
810
+ " # update weights\n",
811
+ " loss = outputs.loss\n",
812
+ " \n",
813
+ " # output current loss\n",
814
+ " loop.set_description(f'Epoch {epoch}')\n",
815
+ " loop.set_postfix(loss=loss.item())\n",
816
+ " losses.append(loss.item())\n",
817
+ " \n",
818
+ " del input_ids\n",
819
+ " del attention_mask\n",
820
+ " del labels\n",
821
+ " print(\"Mean Test Loss\", np.mean(losses))"
822
+ ]
823
+ },
824
+ {
825
+ "cell_type": "code",
826
+ "execution_count": 22,
827
+ "id": "03c23c3e",
828
+ "metadata": {},
829
+ "outputs": [],
830
+ "source": [
831
+ "# save the pretrained model\n",
832
+ "torch.save(model, \"distilbert.model\")"
833
+ ]
834
+ },
835
+ {
836
+ "cell_type": "code",
837
+ "execution_count": 25,
838
+ "id": "9b18d3e3",
839
+ "metadata": {},
840
+ "outputs": [],
841
+ "source": [
842
+ "model = torch.load(\"distilbert.model\")"
843
+ ]
844
+ },
845
+ {
846
+ "cell_type": "markdown",
847
+ "id": "e6ad94db",
848
+ "metadata": {},
849
+ "source": [
850
+ "### Testing\n",
851
+ "Huggingface provides a library to quickly be able to see what word the model would predict for our masked token."
852
+ ]
853
+ },
854
+ {
855
+ "cell_type": "code",
856
+ "execution_count": 27,
857
+ "id": "7c8582d2",
858
+ "metadata": {},
859
+ "outputs": [],
860
+ "source": [
861
+ "fill = pipeline(\"fill-mask\", model='distilbert', config=config, tokenizer='distilbert_tokenizer')"
862
+ ]
863
+ },
864
+ {
865
+ "cell_type": "code",
866
+ "execution_count": 28,
867
+ "id": "d309e57f",
868
+ "metadata": {},
869
+ "outputs": [
870
+ {
871
+ "data": {
872
+ "text/plain": [
873
+ "[{'score': 0.19730663299560547,\n",
874
+ " 'token': 2965,\n",
875
+ " 'token_str': 'change',\n",
876
+ " 'sequence': 'it seems important to tackle the climate change.'},\n",
877
+ " {'score': 0.12946806848049164,\n",
878
+ " 'token': 5215,\n",
879
+ " 'token_str': 'crisis',\n",
880
+ " 'sequence': 'it seems important to tackle the climate crisis.'},\n",
881
+ " {'score': 0.05868387222290039,\n",
882
+ " 'token': 3688,\n",
883
+ " 'token_str': 'issues',\n",
884
+ " 'sequence': 'it seems important to tackle the climate issues.'},\n",
885
+ " {'score': 0.047418754547834396,\n",
886
+ " 'token': 3406,\n",
887
+ " 'token_str': 'issue',\n",
888
+ " 'sequence': 'it seems important to tackle the climate issue.'},\n",
889
+ " {'score': 0.027855267748236656,\n",
890
+ " 'token': 2629,\n",
891
+ " 'token_str': 'here',\n",
892
+ " 'sequence': 'it seems important to tackle the climate here.'}]"
893
+ ]
894
+ },
895
+ "execution_count": 28,
896
+ "metadata": {},
897
+ "output_type": "execute_result"
898
+ }
899
+ ],
900
+ "source": [
901
+ "fill(f'It seems important to tackle the climate {fill.tokenizer.mask_token}.')"
902
+ ]
903
+ },
904
+ {
905
+ "cell_type": "code",
906
+ "execution_count": null,
907
+ "id": "94e3e623",
908
+ "metadata": {},
909
+ "outputs": [],
910
+ "source": []
911
+ }
912
+ ],
913
+ "metadata": {
914
+ "kernelspec": {
915
+ "display_name": "Python 3.10.8 ('venv': venv)",
916
+ "language": "python",
917
+ "name": "python3"
918
+ },
919
+ "language_info": {
920
+ "codemirror_mode": {
921
+ "name": "ipython",
922
+ "version": 3
923
+ },
924
+ "file_extension": ".py",
925
+ "mimetype": "text/x-python",
926
+ "name": "python",
927
+ "nbconvert_exporter": "python",
928
+ "pygments_lexer": "ipython3",
929
+ "version": "3.10.16"
930
+ },
931
+ "toc": {
932
+ "base_numbering": 1,
933
+ "nav_menu": {},
934
+ "number_sections": true,
935
+ "sideBar": true,
936
+ "skip_h1_title": false,
937
+ "title_cell": "Table of Contents",
938
+ "title_sidebar": "Contents",
939
+ "toc_cell": false,
940
+ "toc_position": {},
941
+ "toc_section_display": true,
942
+ "toc_window_display": false
943
+ },
944
+ "varInspector": {
945
+ "cols": {
946
+ "lenName": 16,
947
+ "lenType": 16,
948
+ "lenVar": 40
949
+ },
950
+ "kernels_config": {
951
+ "python": {
952
+ "delete_cmd_postfix": "",
953
+ "delete_cmd_prefix": "del ",
954
+ "library": "var_list.py",
955
+ "varRefreshCmd": "print(var_dic_list())"
956
+ },
957
+ "r": {
958
+ "delete_cmd_postfix": ") ",
959
+ "delete_cmd_prefix": "rm(",
960
+ "library": "var_list.r",
961
+ "varRefreshCmd": "cat(var_dic_list()) "
962
+ }
963
+ },
964
+ "types_to_exclude": [
965
+ "module",
966
+ "function",
967
+ "builtin_function_or_method",
968
+ "instance",
969
+ "_Feature"
970
+ ],
971
+ "window_display": false
972
+ },
973
+ "vscode": {
974
+ "interpreter": {
975
+ "hash": "85bf9c14e9ba73b783ed1274d522bec79eb0b2b739090180d8ce17bb11aff4aa"
976
+ }
977
+ }
978
+ },
979
+ "nbformat": 4,
980
+ "nbformat_minor": 5
981
+ }
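
For quick checks after training, the artefacts saved by this notebook can be reloaded outside Jupyter. The following is a minimal sketch, assuming `distilbert.model` and the `distilbert_tokenizer/` directory were produced by the cells above.

```python
# Minimal sketch: reload the trained MLM and query it, mirroring the last cells of distilbert.ipynb.
# Assumes distilbert.model (saved via torch.save(model, ...)) and distilbert_tokenizer/ exist locally.
import torch
from transformers import pipeline

model = torch.load("distilbert.model")
fill = pipeline("fill-mask", model=model, tokenizer="distilbert_tokenizer")

# prints the top predictions for the masked word
print(fill(f"It seems important to tackle the climate {fill.tokenizer.mask_token}."))
```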
distilbert.py ADDED
@@ -0,0 +1,175 @@
+ import torch
+
+ class Dataset(torch.utils.data.Dataset):
+     """
+     This class loads and preprocesses the given text data
+     """
+     def __init__(self, paths, tokenizer):
+         """
+         This function initialises the object. It takes the given paths and tokeniser.
+         """
+         # the last file might not have 10000 samples, which makes it difficult to get the total length of the ds
+         self.paths = paths[:len(paths)-1]
+         self.tokenizer = tokenizer
+         self.data = self.read_file(self.paths[0])
+         self.current_file = 1
+         self.remaining = len(self.data)
+         self.encodings = self.get_encodings(self.data)
+
+     def __len__(self):
+         """
+         returns the length of the ds
+         """
+         return 10000*len(self.paths)
+
+     def read_file(self, path):
+         """
+         reads a given file
+         """
+         with open(path, 'r', encoding='utf-8') as f:
+             lines = f.read().split('\n')
+         return lines
+
+     def get_encodings(self, lines_all):
+         """
+         Creates encodings for a given text input
+         """
+         # tokenise all text
+         batch = self.tokenizer(lines_all, max_length=512, padding='max_length', truncation=True)
+
+         # Ground Truth
+         labels = torch.tensor(batch['input_ids'])
+         # Attention Masks
+         mask = torch.tensor(batch['attention_mask'])
+
+         # Input to be masked
+         input_ids = labels.detach().clone()
+         rand = torch.rand(input_ids.shape)
+
+         # with a probability of 15%, mask a given word, leave out CLS, SEP and PAD
+         mask_arr = (rand < .15) * (input_ids != 0) * (input_ids != 2) * (input_ids != 3)
+         # assign token 4 (=MASK)
+         input_ids[mask_arr] = 4
+
+         return {'input_ids':input_ids, 'attention_mask':mask, 'labels':labels}
+
+     def __getitem__(self, i):
+         """
+         returns item i
+         Note: do not use shuffling for this dataset
+         """
+         # if we have looked at all items in the file - take next
+         if self.remaining == 0:
+             self.data = self.read_file(self.paths[self.current_file])
+             self.current_file += 1
+             self.remaining = len(self.data)
+             self.encodings = self.get_encodings(self.data)
+
+         # if we are at the end of the dataset, start over again
+         if self.current_file == len(self.paths):
+             self.current_file = 0
+
+         self.remaining -= 1
+         return {key: tensor[i%10000] for key, tensor in self.encodings.items()}
+
+ def test_model(model, optim, test_ds_loader, device):
+     """
+     This function tests that the parameters of the model that are not frozen change during a training step, that frozen parameters do not change,
+     and that no parameter becomes NaN or Inf
+     :param model: model to be tested
+     :param optim: optimiser used for training
+     :param test_ds_loader: dataset to perform the forward pass on
+     :param device: current device
+     :raises Exception: if any of the above conditions are not met
+     """
+     ## Check if non-frozen parameters changed and frozen ones did not
+
+     # get initial parameters to check against
+     params = [ np for np in model.named_parameters() if np[1].requires_grad ]
+     initial_params = [ (name, p.clone()) for (name, p) in params ]
+
+     params_frozen = [ np for np in model.named_parameters() if not np[1].requires_grad ]
+     initial_params_frozen = [ (name, p.clone()) for (name, p) in params_frozen ]
+
+     optim.zero_grad()
+
+     # get data
+     batch = next(iter(test_ds_loader))
+
+     input_ids = batch['input_ids'].to(device)
+     attention_mask = batch['attention_mask'].to(device)
+     labels = batch['labels'].to(device)
+
+     # forward pass and backpropagation
+     outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
+     loss = outputs.loss
+     loss.backward()
+     optim.step()
+
+     # check if variables have changed
+     for (_, p0), (name, p1) in zip(initial_params, params):
+         # check different than initial
+         try:
+             assert not torch.equal(p0.to(device), p1.to(device))
+         except AssertionError:
+             raise Exception(
+                 "{var_name} {msg}".format(
+                     var_name=name,
+                     msg='did not change!'
+                 )
+             )
+         # check not NaN
+         try:
+             assert not torch.isnan(p1).byte().any()
+         except AssertionError:
+             raise Exception(
+                 "{var_name} {msg}".format(
+                     var_name=name,
+                     msg='is NaN!'
+                 )
+             )
+         # check finite
+         try:
+             assert torch.isfinite(p1).byte().all()
+         except AssertionError:
+             raise Exception(
+                 "{var_name} {msg}".format(
+                     var_name=name,
+                     msg='is Inf!'
+                 )
+             )
+
+     # check that frozen weights have not changed
+     for (_, p0), (name, p1) in zip(initial_params_frozen, params_frozen):
+         # should be the same
+         try:
+             assert torch.equal(p0.to(device), p1.to(device))
+         except AssertionError:
+             raise Exception(
+                 "{var_name} {msg}".format(
+                     var_name=name,
+                     msg='changed!'
+                 )
+             )
+         # check not NaN
+         try:
+             assert not torch.isnan(p1).byte().any()
+         except AssertionError:
+             raise Exception(
+                 "{var_name} {msg}".format(
+                     var_name=name,
+                     msg='is NaN!'
+                 )
+             )
+
+         # check finite numbers
+         try:
+             assert torch.isfinite(p1).byte().all()
+         except AssertionError:
+             raise Exception(
+                 "{var_name} {msg}".format(
+                     var_name=name,
+                     msg='is Inf!'
+                 )
+             )
+     print("Passed")
load_data.ipynb ADDED
@@ -0,0 +1,1209 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "12d87b30",
6
+ "metadata": {},
7
+ "source": [
8
+ "# Load Data\n",
9
+ "This notebook loads and preproceses all necessary data, namely the following.\n",
10
+ "* OpenWebTextCorpus: for base DistilBERT model\n",
11
+ "* SQuAD datasrt: for Q&A\n",
12
+ "* Natural Questions (needs to be downloaded externally but is preprocessed here): for Q&A\n",
13
+ "* HotPotQA: for Q&A"
14
+ ]
15
+ },
16
+ {
17
+ "cell_type": "code",
18
+ "execution_count": 4,
19
+ "id": "7c82d7fa",
20
+ "metadata": {},
21
+ "outputs": [],
22
+ "source": [
23
+ "from tqdm.auto import tqdm\n",
24
+ "from datasets import load_dataset\n",
25
+ "import os\n",
26
+ "import pandas as pd\n",
27
+ "import random"
28
+ ]
29
+ },
30
+ {
31
+ "cell_type": "markdown",
32
+ "id": "1737f219",
33
+ "metadata": {},
34
+ "source": [
35
+ "## Distilbert Data\n",
36
+ "In the following, we download the english openwebtext dataset from huggingface (https://huggingface.co/datasets/openwebtext). The dataset is provided by Aaron Gokaslan and Vanya Cohen from Brown University (https://skylion007.github.io/OpenWebTextCorpus/).\n",
37
+ "\n",
38
+ "We first load the data, investigate the structure and write the dataset into files of each 10 000 texts."
39
+ ]
40
+ },
41
+ {
42
+ "cell_type": "code",
43
+ "execution_count": null,
44
+ "id": "cce7623c",
45
+ "metadata": {},
46
+ "outputs": [],
47
+ "source": [
48
+ "ds = load_dataset(\"openwebtext\")"
49
+ ]
50
+ },
51
+ {
52
+ "cell_type": "code",
53
+ "execution_count": 4,
54
+ "id": "678a5e86",
55
+ "metadata": {},
56
+ "outputs": [
57
+ {
58
+ "data": {
59
+ "text/plain": [
60
+ "DatasetDict({\n",
61
+ " train: Dataset({\n",
62
+ " features: ['text'],\n",
63
+ " num_rows: 8013769\n",
64
+ " })\n",
65
+ "})"
66
+ ]
67
+ },
68
+ "execution_count": 4,
69
+ "metadata": {},
70
+ "output_type": "execute_result"
71
+ }
72
+ ],
73
+ "source": [
74
+ "# we have a text-only training dataset with 8 million entries\n",
75
+ "ds"
76
+ ]
77
+ },
78
+ {
79
+ "cell_type": "code",
80
+ "execution_count": 5,
81
+ "id": "b141bce7",
82
+ "metadata": {},
83
+ "outputs": [],
84
+ "source": [
85
+ "# create necessary folders\n",
86
+ "os.mkdir('data')\n",
87
+ "os.mkdir('data/original')"
88
+ ]
89
+ },
90
+ {
91
+ "cell_type": "code",
92
+ "execution_count": null,
93
+ "id": "ca94f995",
94
+ "metadata": {},
95
+ "outputs": [],
96
+ "source": [
97
+ "# save text in chunks of 10000 samples\n",
98
+ "text = []\n",
99
+ "i = 0\n",
100
+ "\n",
101
+ "for sample in tqdm(ds['train']):\n",
102
+ " # replace all newlines\n",
103
+ " sample = sample['text'].replace('\\n','')\n",
104
+ " \n",
105
+ " # append cleaned sample to all texts\n",
106
+ " text.append(sample)\n",
107
+ " \n",
108
+ " # if we processed 10000 samples, write them to a file and start over\n",
109
+ " if len(text) == 10000:\n",
110
+ " with open(f\"data/original/text_{i}.txt\", 'w', encoding='utf-8') as f:\n",
111
+ " f.write('\\n'.join(text))\n",
112
+ " text = []\n",
113
+ " i += 1 \n",
114
+ "\n",
115
+ "# write remaining samples to a file\n",
116
+ "with open(f\"data/original/text_{i}.txt\", 'w', encoding='utf-8') as f:\n",
117
+ " f.write('\\n'.join(text))"
118
+ ]
119
+ },
120
+ {
121
+ "cell_type": "markdown",
122
+ "id": "f131dcfc",
123
+ "metadata": {},
124
+ "source": [
125
+ "### Testing\n",
126
+ "If we load the first file, we should get a file that is 10000 lines long and has one column\n",
127
+ "\n",
128
+ "As we do not preprocess the data in any way, but just write the read text into the file, this is all testing necessary"
129
+ ]
130
+ },
131
+ {
132
+ "cell_type": "code",
133
+ "execution_count": 13,
134
+ "id": "df50af74",
135
+ "metadata": {},
136
+ "outputs": [],
137
+ "source": [
138
+ "with open(\"data/original/text_0.txt\", 'r', encoding='utf-8') as f:\n",
139
+ " lines = f.read().split('\\n')\n",
140
+ "lines = pd.DataFrame(lines)"
141
+ ]
142
+ },
143
+ {
144
+ "cell_type": "code",
145
+ "execution_count": 14,
146
+ "id": "8ddb0085",
147
+ "metadata": {},
148
+ "outputs": [
149
+ {
150
+ "name": "stdout",
151
+ "output_type": "stream",
152
+ "text": [
153
+ "Passed\n"
154
+ ]
155
+ }
156
+ ],
157
+ "source": [
158
+ "assert lines.shape==(10000,1)\n",
159
+ "print(\"Passed\")"
160
+ ]
161
+ },
162
+ {
163
+ "cell_type": "markdown",
164
+ "id": "1a65b268",
165
+ "metadata": {},
166
+ "source": [
167
+ "## SQuAD Data\n",
168
+ "In the following, we download the SQuAD dataset from huggingface (https://huggingface.co/datasets/squad). It was initially provided by Rajpurkar et al. from Stanford University.\n",
169
+ "\n",
170
+ "We again load the dataset and store it in chunks of 1000 into files."
171
+ ]
172
+ },
173
+ {
174
+ "cell_type": "code",
175
+ "execution_count": 6,
176
+ "id": "6750ce6e",
177
+ "metadata": {},
178
+ "outputs": [
179
+ {
180
+ "ename": "AssertionError",
181
+ "evalue": "",
182
+ "output_type": "error",
183
+ "traceback": [
184
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
185
+ "\u001b[0;31mAssertionError\u001b[0m Traceback (most recent call last)",
186
+ "Cell \u001b[0;32mIn [6], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m dataset \u001b[38;5;241m=\u001b[39m load_dataset(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124msquad\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n",
187
+ "File \u001b[0;32m~/anaconda3/envs/myenv/lib/python3.10/site-packages/datasets/load.py:1670\u001b[0m, in \u001b[0;36mload_dataset\u001b[0;34m(path, name, data_dir, data_files, split, cache_dir, features, download_config, download_mode, ignore_verifications, keep_in_memory, save_infos, revision, use_auth_token, task, streaming, **config_kwargs)\u001b[0m\n\u001b[1;32m 1667\u001b[0m ignore_verifications \u001b[38;5;241m=\u001b[39m ignore_verifications \u001b[38;5;129;01mor\u001b[39;00m save_infos\n\u001b[1;32m 1669\u001b[0m \u001b[38;5;66;03m# Create a dataset builder\u001b[39;00m\n\u001b[0;32m-> 1670\u001b[0m builder_instance \u001b[38;5;241m=\u001b[39m \u001b[43mload_dataset_builder\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1671\u001b[0m \u001b[43m \u001b[49m\u001b[43mpath\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpath\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1672\u001b[0m \u001b[43m \u001b[49m\u001b[43mname\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mname\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1673\u001b[0m \u001b[43m \u001b[49m\u001b[43mdata_dir\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdata_dir\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1674\u001b[0m \u001b[43m \u001b[49m\u001b[43mdata_files\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdata_files\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1675\u001b[0m \u001b[43m \u001b[49m\u001b[43mcache_dir\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcache_dir\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1676\u001b[0m \u001b[43m \u001b[49m\u001b[43mfeatures\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mfeatures\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1677\u001b[0m \u001b[43m \u001b[49m\u001b[43mdownload_config\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdownload_config\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1678\u001b[0m \u001b[43m \u001b[49m\u001b[43mdownload_mode\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdownload_mode\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1679\u001b[0m \u001b[43m \u001b[49m\u001b[43mrevision\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mrevision\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1680\u001b[0m \u001b[43m \u001b[49m\u001b[43muse_auth_token\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43muse_auth_token\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1681\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mconfig_kwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1682\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1684\u001b[0m \u001b[38;5;66;03m# Return iterable dataset in case of streaming\u001b[39;00m\n\u001b[1;32m 1685\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m streaming:\n",
188
+ "File \u001b[0;32m~/anaconda3/envs/myenv/lib/python3.10/site-packages/datasets/load.py:1447\u001b[0m, in \u001b[0;36mload_dataset_builder\u001b[0;34m(path, name, data_dir, data_files, cache_dir, features, download_config, download_mode, revision, use_auth_token, **config_kwargs)\u001b[0m\n\u001b[1;32m 1445\u001b[0m download_config \u001b[38;5;241m=\u001b[39m download_config\u001b[38;5;241m.\u001b[39mcopy() \u001b[38;5;28;01mif\u001b[39;00m download_config \u001b[38;5;28;01melse\u001b[39;00m DownloadConfig()\n\u001b[1;32m 1446\u001b[0m download_config\u001b[38;5;241m.\u001b[39muse_auth_token \u001b[38;5;241m=\u001b[39m use_auth_token\n\u001b[0;32m-> 1447\u001b[0m dataset_module \u001b[38;5;241m=\u001b[39m \u001b[43mdataset_module_factory\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1448\u001b[0m \u001b[43m \u001b[49m\u001b[43mpath\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1449\u001b[0m \u001b[43m \u001b[49m\u001b[43mrevision\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mrevision\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1450\u001b[0m \u001b[43m \u001b[49m\u001b[43mdownload_config\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdownload_config\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1451\u001b[0m \u001b[43m \u001b[49m\u001b[43mdownload_mode\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdownload_mode\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1452\u001b[0m \u001b[43m \u001b[49m\u001b[43mdata_dir\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdata_dir\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1453\u001b[0m \u001b[43m \u001b[49m\u001b[43mdata_files\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdata_files\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1454\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1456\u001b[0m \u001b[38;5;66;03m# Get dataset builder class from the processing script\u001b[39;00m\n\u001b[1;32m 1457\u001b[0m builder_cls \u001b[38;5;241m=\u001b[39m import_main_class(dataset_module\u001b[38;5;241m.\u001b[39mmodule_path)\n",
189
+ "File \u001b[0;32m~/anaconda3/envs/myenv/lib/python3.10/site-packages/datasets/load.py:1172\u001b[0m, in \u001b[0;36mdataset_module_factory\u001b[0;34m(path, revision, download_config, download_mode, dynamic_modules_path, data_dir, data_files, **download_kwargs)\u001b[0m\n\u001b[1;32m 1167\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(e1, \u001b[38;5;167;01mFileNotFoundError\u001b[39;00m):\n\u001b[1;32m 1168\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mFileNotFoundError\u001b[39;00m(\n\u001b[1;32m 1169\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mCouldn\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mt find a dataset script at \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mrelative_to_absolute_path(combined_path)\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m or any data file in the same directory. \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 1170\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mCouldn\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mt find \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mpath\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m on the Hugging Face Hub either: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mtype\u001b[39m(e1)\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00me1\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 1171\u001b[0m ) \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;28mNone\u001b[39m\n\u001b[0;32m-> 1172\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m e1 \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;28mNone\u001b[39m\n\u001b[1;32m 1173\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 1174\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mFileNotFoundError\u001b[39;00m(\n\u001b[1;32m 1175\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mCouldn\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mt find a dataset script at \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mrelative_to_absolute_path(combined_path)\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m or any data file in the same directory.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 1176\u001b[0m )\n",
190
+ "File \u001b[0;32m~/anaconda3/envs/myenv/lib/python3.10/site-packages/datasets/load.py:1151\u001b[0m, in \u001b[0;36mdataset_module_factory\u001b[0;34m(path, revision, download_config, download_mode, dynamic_modules_path, data_dir, data_files, **download_kwargs)\u001b[0m\n\u001b[1;32m 1143\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m HubDatasetModuleFactoryWithScript(\n\u001b[1;32m 1144\u001b[0m path,\n\u001b[1;32m 1145\u001b[0m revision\u001b[38;5;241m=\u001b[39mrevision,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1148\u001b[0m dynamic_modules_path\u001b[38;5;241m=\u001b[39mdynamic_modules_path,\n\u001b[1;32m 1149\u001b[0m )\u001b[38;5;241m.\u001b[39mget_module()\n\u001b[1;32m 1150\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1151\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mHubDatasetModuleFactoryWithoutScript\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1152\u001b[0m \u001b[43m \u001b[49m\u001b[43mpath\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1153\u001b[0m \u001b[43m \u001b[49m\u001b[43mrevision\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mrevision\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1154\u001b[0m \u001b[43m \u001b[49m\u001b[43mdata_dir\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdata_dir\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1155\u001b[0m \u001b[43m \u001b[49m\u001b[43mdata_files\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdata_files\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1156\u001b[0m \u001b[43m \u001b[49m\u001b[43mdownload_config\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdownload_config\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1157\u001b[0m \u001b[43m \u001b[49m\u001b[43mdownload_mode\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdownload_mode\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1158\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241m.\u001b[39mget_module()\n\u001b[1;32m 1159\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mException\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m e1: \u001b[38;5;66;03m# noqa: all the attempts failed, before raising the error we should check if the module is already cached.\u001b[39;00m\n\u001b[1;32m 1160\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n",
191
+ "File \u001b[0;32m~/anaconda3/envs/myenv/lib/python3.10/site-packages/datasets/load.py:744\u001b[0m, in \u001b[0;36mHubDatasetModuleFactoryWithoutScript.__init__\u001b[0;34m(self, name, revision, data_dir, data_files, download_config, download_mode)\u001b[0m\n\u001b[1;32m 742\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdownload_config \u001b[38;5;241m=\u001b[39m download_config \u001b[38;5;129;01mor\u001b[39;00m DownloadConfig()\n\u001b[1;32m 743\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdownload_mode \u001b[38;5;241m=\u001b[39m download_mode\n\u001b[0;32m--> 744\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mname\u001b[38;5;241m.\u001b[39mcount(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m/\u001b[39m\u001b[38;5;124m\"\u001b[39m) \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[1;32m 745\u001b[0m increase_load_count(name, resource_type\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mdataset\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n",
192
+ "\u001b[0;31mAssertionError\u001b[0m: "
193
+ ]
194
+ }
195
+ ],
196
+ "source": [
197
+ "dataset = load_dataset(\"squad\")"
198
+ ]
199
+ },
200
+ {
201
+ "cell_type": "code",
202
+ "execution_count": null,
203
+ "id": "65a7ee23",
204
+ "metadata": {},
205
+ "outputs": [
206
+ {
207
+ "ename": "",
208
+ "evalue": "",
209
+ "output_type": "error",
210
+ "traceback": [
211
+ "\u001b[1;31mnotebook controller is DISPOSED. \n",
212
+ "\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
213
+ ]
214
+ },
215
+ {
216
+ "ename": "",
217
+ "evalue": "",
218
+ "output_type": "error",
219
+ "traceback": [
220
+ "\u001b[1;31mnotebook controller is DISPOSED. \n",
221
+ "\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
222
+ ]
223
+ }
224
+ ],
225
+ "source": [
226
+ "os.mkdir(\"data/training_squad\")\n",
227
+ "os.mkdir(\"data/test_squad\")"
228
+ ]
229
+ },
230
+ {
231
+ "cell_type": "code",
232
+ "execution_count": null,
233
+ "id": "f6ebf63e",
234
+ "metadata": {},
235
+ "outputs": [
236
+ {
237
+ "ename": "",
238
+ "evalue": "",
239
+ "output_type": "error",
240
+ "traceback": [
241
+ "\u001b[1;31mnotebook controller is DISPOSED. \n",
242
+ "\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
243
+ ]
244
+ },
245
+ {
246
+ "ename": "",
247
+ "evalue": "",
248
+ "output_type": "error",
249
+ "traceback": [
250
+ "\u001b[1;31mnotebook controller is DISPOSED. \n",
251
+ "\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
252
+ ]
253
+ }
254
+ ],
255
+ "source": [
256
+ "# we already have a training and test split. Each sample has an id, title, context, question and answers.\n",
257
+ "dataset"
258
+ ]
259
+ },
260
+ {
261
+ "cell_type": "code",
262
+ "execution_count": null,
263
+ "id": "f67ae448",
264
+ "metadata": {},
265
+ "outputs": [
266
+ {
267
+ "ename": "",
268
+ "evalue": "",
269
+ "output_type": "error",
270
+ "traceback": [
271
+ "\u001b[1;31mnotebook controller is DISPOSED. \n",
272
+ "\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
273
+ ]
274
+ },
275
+ {
276
+ "ename": "",
277
+ "evalue": "",
278
+ "output_type": "error",
279
+ "traceback": [
280
+ "\u001b[1;31mnotebook controller is DISPOSED. \n",
281
+ "\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
282
+ ]
283
+ }
284
+ ],
285
+ "source": [
286
+ "# answers are provided like that - we need to extract answer_end for the model\n",
287
+ "dataset['train']['answers'][0]"
288
+ ]
289
+ },
290
+ {
291
+ "cell_type": "code",
292
+ "execution_count": null,
293
+ "id": "101cd650",
294
+ "metadata": {},
295
+ "outputs": [
296
+ {
297
+ "ename": "",
298
+ "evalue": "",
299
+ "output_type": "error",
300
+ "traceback": [
301
+ "\u001b[1;31mnotebook controller is DISPOSED. \n",
302
+ "\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
303
+ ]
304
+ },
305
+ {
306
+ "ename": "",
307
+ "evalue": "",
308
+ "output_type": "error",
309
+ "traceback": [
310
+ "\u001b[1;31mnotebook controller is DISPOSED. \n",
311
+ "\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
312
+ ]
313
+ }
314
+ ],
315
+ "source": [
316
+ "# column contains the split (either train or validation), save_dir is the directory\n",
317
+ "def save_samples(column, save_dir):\n",
318
+ " text = []\n",
319
+ " i = 0\n",
320
+ "\n",
321
+ " for sample in tqdm(dataset[column]):\n",
322
+ " \n",
323
+ " # preprocess the context and question by removing the newlines\n",
324
+ " context = sample['context'].replace('\\n','')\n",
325
+ " question = sample['question'].replace('\\n','')\n",
326
+ "\n",
327
+ " # get the answer as text and start character index\n",
328
+ " answer_text = sample['answers']['text'][0]\n",
329
+ " answer_start = str(sample['answers']['answer_start'][0])\n",
330
+ " \n",
331
+ " text.append([context, question, answer_text, answer_start])\n",
332
+ "\n",
333
+ " # we choose chunks of 1000\n",
334
+ " if len(text) == 1000:\n",
335
+ " with open(f\"data/{save_dir}/text_{i}.txt\", 'w', encoding='utf-8') as f:\n",
336
+ " f.write(\"\\n\".join([\"\\t\".join(t) for t in text]))\n",
337
+ " text = []\n",
338
+ " i += 1\n",
339
+ "\n",
340
+ " # save remaining\n",
341
+ " with open(f\"data/{save_dir}/text_{i}.txt\", 'w', encoding='utf-8') as f:\n",
342
+ " f.write(\"\\n\".join([\"\\t\".join(t) for t in text]))\n",
343
+ "\n",
344
+ "save_samples(\"train\", \"training_squad\")\n",
345
+ "save_samples(\"validation\", \"test_squad\")\n",
346
+ " "
347
+ ]
348
+ },
349
+ {
350
+ "cell_type": "markdown",
351
+ "id": "67044d13",
352
+ "metadata": {
353
+ "collapsed": false,
354
+ "jupyter": {
355
+ "outputs_hidden": false
356
+ }
357
+ },
358
+ "source": [
359
+ "### Testing\n",
360
+ "If we load a file, we should get a file with 10000 lines and 4 columns\n",
361
+ "\n",
362
+ "Also, we want to assure the correct interval. Hence, the second test."
363
+ ]
364
+ },
365
+ {
366
+ "cell_type": "code",
367
+ "execution_count": null,
368
+ "id": "446281cf",
369
+ "metadata": {},
370
+ "outputs": [
371
+ {
372
+ "ename": "",
373
+ "evalue": "",
374
+ "output_type": "error",
375
+ "traceback": [
376
+ "\u001b[1;31mnotebook controller is DISPOSED. \n",
377
+ "\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
378
+ ]
379
+ },
380
+ {
381
+ "ename": "",
382
+ "evalue": "",
383
+ "output_type": "error",
384
+ "traceback": [
385
+ "\u001b[1;31mnotebook controller is DISPOSED. \n",
386
+ "\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
387
+ ]
388
+ }
389
+ ],
390
+ "source": [
391
+ "with open(\"data/training_squad/text_0.txt\", 'r', encoding='utf-8') as f:\n",
392
+ " lines = f.read().split('\\n')\n",
393
+ " \n",
394
+ "lines = pd.DataFrame([line.split(\"\\t\") for line in lines], columns=[\"context\", \"question\", \"answer\", \"answer_start\"])"
395
+ ]
396
+ },
397
+ {
398
+ "cell_type": "code",
399
+ "execution_count": null,
400
+ "id": "ccd5c650",
401
+ "metadata": {},
402
+ "outputs": [
403
+ {
404
+ "ename": "",
405
+ "evalue": "",
406
+ "output_type": "error",
407
+ "traceback": [
408
+ "\u001b[1;31mnotebook controller is DISPOSED. \n",
409
+ "\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
410
+ ]
411
+ },
412
+ {
413
+ "ename": "",
414
+ "evalue": "",
415
+ "output_type": "error",
416
+ "traceback": [
417
+ "\u001b[1;31mnotebook controller is DISPOSED. \n",
418
+ "\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
419
+ ]
420
+ }
421
+ ],
422
+ "source": [
423
+ "assert lines.shape==(1000,4)\n",
424
+ "print(\"Passed\")"
425
+ ]
426
+ },
427
+ {
428
+ "cell_type": "code",
429
+ "execution_count": null,
430
+ "id": "2c9e4b70",
431
+ "metadata": {},
432
+ "outputs": [
433
+ {
434
+ "ename": "",
435
+ "evalue": "",
436
+ "output_type": "error",
437
+ "traceback": [
438
+ "\u001b[1;31mnotebook controller is DISPOSED. \n",
439
+ "\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
440
+ ]
441
+ },
442
+ {
443
+ "ename": "",
444
+ "evalue": "",
445
+ "output_type": "error",
446
+ "traceback": [
447
+ "\u001b[1;31mnotebook controller is DISPOSED. \n",
448
+ "\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
449
+ ]
450
+ }
451
+ ],
452
+ "source": [
453
+ "# we assert that we have the right interval\n",
454
+ "for ind, line in lines.iterrows():\n",
455
+ " sample = line\n",
456
+ " answer_start = int(sample['answer_start'])\n",
457
+ " assert sample['context'][answer_start:answer_start+len(sample['answer'])] == sample['answer']\n",
458
+ "print(\"Passed\")"
459
+ ]
460
+ },
461
+ {
462
+ "cell_type": "markdown",
463
+ "id": "02265ace",
464
+ "metadata": {},
465
+ "source": [
466
+ "## Natural Questions Dataset\n",
467
+ "* Download from https://ai.google.com/research/NaturalQuestions via gsutil (the one from huggingface has 134.92GB, the one from google cloud is in archives)\n",
468
+ "* Use gunzip to get some samples - we then get `.jsonl`files\n",
469
+ "* The dataset is a lot more messy, as it is just wikipedia articles with all web artifacts\n",
470
+ " * I cleaned the html tags\n",
471
+ " * Also I chose a random interval (containing the answer) from the dataset\n",
472
+ " * We can't send the whole text into the model anyways"
473
+ ]
474
+ },
475
+ {
476
+ "cell_type": "code",
477
+ "execution_count": null,
478
+ "id": "f3bce0c1",
479
+ "metadata": {},
480
+ "outputs": [
481
+ {
482
+ "ename": "",
483
+ "evalue": "",
484
+ "output_type": "error",
485
+ "traceback": [
486
+ "\u001b[1;31mnotebook controller is DISPOSED. \n",
487
+ "\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
488
+ ]
489
+ },
490
+ {
491
+ "ename": "",
492
+ "evalue": "",
493
+ "output_type": "error",
494
+ "traceback": [
495
+ "\u001b[1;31mnotebook controller is DISPOSED. \n",
496
+ "\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
497
+ ]
498
+ }
499
+ ],
500
+ "source": [
501
+ "from pathlib import Path\n",
502
+ "paths = [str(x) for x in Path('data/natural_questions/v1.0/train/').glob('**/*.jsonl')]"
503
+ ]
504
+ },
505
+ {
506
+ "cell_type": "code",
507
+ "execution_count": null,
508
+ "id": "e9c58c00",
509
+ "metadata": {},
510
+ "outputs": [
511
+ {
512
+ "ename": "",
513
+ "evalue": "",
514
+ "output_type": "error",
515
+ "traceback": [
516
+ "\u001b[1;31mnotebook controller is DISPOSED. \n",
517
+ "\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
518
+ ]
519
+ },
520
+ {
521
+ "ename": "",
522
+ "evalue": "",
523
+ "output_type": "error",
524
+ "traceback": [
525
+ "\u001b[1;31mnotebook controller is DISPOSED. \n",
526
+ "\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
527
+ ]
528
+ }
529
+ ],
530
+ "source": [
531
+ "os.mkdir(\"data/natural_questions_train\")"
532
+ ]
533
+ },
534
+ {
535
+ "cell_type": "code",
536
+ "execution_count": null,
537
+ "id": "0ed7ba6c",
538
+ "metadata": {},
539
+ "outputs": [
540
+ {
541
+ "ename": "",
542
+ "evalue": "",
543
+ "output_type": "error",
544
+ "traceback": [
545
+ "\u001b[1;31mnotebook controller is DISPOSED. \n",
546
+ "\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
547
+ ]
548
+ },
549
+ {
550
+ "ename": "",
551
+ "evalue": "",
552
+ "output_type": "error",
553
+ "traceback": [
554
+ "\u001b[1;31mnotebook controller is DISPOSED. \n",
555
+ "\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
556
+ ]
557
+ }
558
+ ],
559
+ "source": [
560
+ "import re\n",
561
+ "\n",
562
+ "# clean html tags\n",
563
+ "CLEANR = re.compile('<.+?>')\n",
564
+ "# clean multiple spaces\n",
565
+ "CLEANMULTSPACE = re.compile('(\\s)+')\n",
566
+ "\n",
567
+ "# the function takes an html documents and removes artifacts\n",
568
+ "def cleanhtml(raw_html):\n",
569
+ " # tags\n",
570
+ " cleantext = re.sub(CLEANR, '', raw_html)\n",
571
+ " # newlines\n",
572
+ " cleantext = cleantext.replace(\"\\n\", '')\n",
573
+ " # tabs\n",
574
+ " cleantext = cleantext.replace(\"\\t\", '')\n",
575
+ " # character encodings\n",
576
+ " cleantext = cleantext.replace(\"&#39;\", \"'\")\n",
577
+ " cleantext = cleantext.replace(\"&amp;\", \"'\")\n",
578
+ " cleantext = cleantext.replace(\"&quot;\", '\"')\n",
579
+ " # multiple spaces\n",
580
+ " cleantext = re.sub(CLEANMULTSPACE, ' ', cleantext)\n",
581
+ " # documents end with this tags, if it is present in the string, cut it off\n",
582
+ " idx = cleantext.find(\"<!-- NewPP limit\")\n",
583
+ " if idx > -1:\n",
584
+ " cleantext = cleantext[:idx]\n",
585
+ " return cleantext.strip()"
586
+ ]
587
+ },
588
+ {
589
+ "cell_type": "code",
590
+ "execution_count": null,
591
+ "id": "66ca19ac",
592
+ "metadata": {},
593
+ "outputs": [
594
+ {
595
+ "ename": "",
596
+ "evalue": "",
597
+ "output_type": "error",
598
+ "traceback": [
599
+ "\u001b[1;31mnotebook controller is DISPOSED. \n",
600
+ "\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
601
+ ]
602
+ },
603
+ {
604
+ "ename": "",
605
+ "evalue": "",
606
+ "output_type": "error",
607
+ "traceback": [
608
+ "\u001b[1;31mnotebook controller is DISPOSED. \n",
609
+ "\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
610
+ ]
611
+ }
612
+ ],
613
+ "source": [
614
+ "import json\n",
615
+ "\n",
616
+ "# file count\n",
617
+ "i = 0\n",
618
+ "data = []\n",
619
+ "\n",
620
+ "# iterate over all json files\n",
621
+ "for path in paths:\n",
622
+ " print(path)\n",
623
+ " # read file and store as list (this requires much memory, as the files are huge)\n",
624
+ " with open(path, 'r') as json_file:\n",
625
+ " json_list = list(json_file)\n",
626
+ " \n",
627
+ " # process every context, question, answer pair\n",
628
+ " for json_str in json_list:\n",
629
+ " result = json.loads(json_str)\n",
630
+ "\n",
631
+ " # append a question mark - SQuAD questions end with a qm too\n",
632
+ " question = result['question_text'] + \"?\"\n",
633
+ " \n",
634
+ " # some question do not contain an answer - we do not need them\n",
635
+ " if(len(result['annotations'][0]['short_answers'])==0):\n",
636
+ " continue\n",
637
+ "\n",
638
+ " # get true start/end byte\n",
639
+ " true_start = result['annotations'][0]['short_answers'][0]['start_byte']\n",
640
+ " true_end = result['annotations'][0]['short_answers'][0]['end_byte']\n",
641
+ "\n",
642
+ " # convert to bytes\n",
643
+ " byte_encoding = bytes(result['document_html'], encoding='utf-8')\n",
644
+ " \n",
645
+ " # the document is the whole wikipedia article, we randomly choose an appropriate part (containing the\n",
646
+ " # answer): we have 512 tokens as the input for the model - 4000 bytes lead to a good length\n",
647
+ " max_back = 3500 if true_start >= 3500 else true_start\n",
648
+ " first = random.randint(int(true_start)-max_back, int(true_start))\n",
649
+ " end = first + 3500 + true_end - true_start\n",
650
+ " \n",
651
+ " # get chosen context\n",
652
+ " cleanbytes = byte_encoding[first:end]\n",
653
+ " # decode back to text - if our end byte is the middle of a word, we ignore it and cut it off\n",
654
+ " cleantext = bytes.decode(cleanbytes, errors='ignore')\n",
655
+ " # clean html tags\n",
656
+ " cleantext = cleanhtml(cleantext)\n",
657
+ "\n",
658
+ " # find the true answer\n",
659
+ " answer_start = cleanbytes.find(byte_encoding[true_start:true_end])\n",
660
+ " true_answer = bytes.decode(cleanbytes[answer_start:answer_start+(true_end-true_start)])\n",
661
+ " \n",
662
+ " # clean html tags\n",
663
+ " true_answer = cleanhtml(true_answer)\n",
664
+ " \n",
665
+ " start_ind = cleantext.find(true_answer)\n",
666
+ " \n",
667
+ " # If cleaning the string makes the answer not findable skip it\n",
668
+ " # this hardly ever happens, except if there is an emense amount of web artifacts\n",
669
+ " if start_ind == -1:\n",
670
+ " continue\n",
671
+ " \n",
672
+ " data.append([cleantext, question, true_answer, str(start_ind)])\n",
673
+ "\n",
674
+ " if len(data) == 1000:\n",
675
+ " with open(f\"data/natural_questions_train/text_{i}.txt\", 'w', encoding='utf-8') as f:\n",
676
+ " f.write(\"\\n\".join([\"\\t\".join(t) for t in data]))\n",
677
+ " i += 1\n",
678
+ " data = []\n",
679
+ "with open(f\"data/natural_questions_train/text_{i}.txt\", 'w', encoding='utf-8') as f:\n",
680
+ " f.write(\"\\n\".join([\"\\t\".join(t) for t in data]))"
681
+ ]
682
+ },
683
+ {
684
+ "cell_type": "markdown",
685
+ "id": "30f26b4e",
686
+ "metadata": {},
687
+ "source": [
688
+ "### Testing\n",
689
+ "In the following, we first check if the shape of the file is correct.\n",
690
+ "\n",
691
+ "Then we iterate over the file and check if the answers according to the file are the same as in the original file."
692
+ ]
693
+ },
694
+ {
695
+ "cell_type": "code",
696
+ "execution_count": null,
697
+ "id": "490ac0db",
698
+ "metadata": {},
699
+ "outputs": [
700
+ {
701
+ "ename": "",
702
+ "evalue": "",
703
+ "output_type": "error",
704
+ "traceback": [
705
+ "\u001b[1;31mnotebook controller is DISPOSED. \n",
706
+ "\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
707
+ ]
708
+ },
709
+ {
710
+ "ename": "",
711
+ "evalue": "",
712
+ "output_type": "error",
713
+ "traceback": [
714
+ "\u001b[1;31mnotebook controller is DISPOSED. \n",
715
+ "\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
716
+ ]
717
+ }
718
+ ],
719
+ "source": [
720
+ "with open(\"data/natural_questions_train/text_0.txt\", 'r', encoding='utf-8') as f:\n",
721
+ " lines = f.read().split('\\n')\n",
722
+ " \n",
723
+ "lines = pd.DataFrame([line.split(\"\\t\") for line in lines], columns=[\"context\", \"question\", \"answer\", \"answer_start\"])"
724
+ ]
725
+ },
726
+ {
727
+ "cell_type": "code",
728
+ "execution_count": null,
729
+ "id": "0d7cc3ee",
730
+ "metadata": {},
731
+ "outputs": [
732
+ {
733
+ "ename": "",
734
+ "evalue": "",
735
+ "output_type": "error",
736
+ "traceback": [
737
+ "\u001b[1;31mnotebook controller is DISPOSED. \n",
738
+ "\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
739
+ ]
740
+ },
741
+ {
742
+ "ename": "",
743
+ "evalue": "",
744
+ "output_type": "error",
745
+ "traceback": [
746
+ "\u001b[1;31mnotebook controller is DISPOSED. \n",
747
+ "\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
748
+ ]
749
+ }
750
+ ],
751
+ "source": [
752
+ "assert lines.shape == (1000, 4)\n",
753
+ "print(\"Passed\")"
754
+ ]
755
+ },
756
+ {
757
+ "cell_type": "code",
758
+ "execution_count": null,
759
+ "id": "0fd8a854",
760
+ "metadata": {},
761
+ "outputs": [
762
+ {
763
+ "ename": "",
764
+ "evalue": "",
765
+ "output_type": "error",
766
+ "traceback": [
767
+ "\u001b[1;31mnotebook controller is DISPOSED. \n",
768
+ "\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
769
+ ]
770
+ },
771
+ {
772
+ "ename": "",
773
+ "evalue": "",
774
+ "output_type": "error",
775
+ "traceback": [
776
+ "\u001b[1;31mnotebook controller is DISPOSED. \n",
777
+ "\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
778
+ ]
779
+ }
780
+ ],
781
+ "source": [
782
+ "with open(\"data/natural_questions/v1.0/train/nq-train-00.jsonl\", 'r') as json_file:\n",
783
+ " json_list = list(json_file)[:500]\n",
784
+ "del json_file"
785
+ ]
786
+ },
787
+ {
788
+ "cell_type": "code",
789
+ "execution_count": null,
790
+ "id": "170bff30",
791
+ "metadata": {},
792
+ "outputs": [
793
+ {
794
+ "ename": "",
795
+ "evalue": "",
796
+ "output_type": "error",
797
+ "traceback": [
798
+ "\u001b[1;31mnotebook controller is DISPOSED. \n",
799
+ "\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
800
+ ]
801
+ },
802
+ {
803
+ "ename": "",
804
+ "evalue": "",
805
+ "output_type": "error",
806
+ "traceback": [
807
+ "\u001b[1;31mnotebook controller is DISPOSED. \n",
808
+ "\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
809
+ ]
810
+ }
811
+ ],
812
+ "source": [
813
+ "lines_index = 0\n",
814
+ "for i in range(len(json_list)):\n",
815
+ " result = json.loads(json_list[i])\n",
816
+ " \n",
817
+ " if(len(result['annotations'][0]['short_answers'])==0):\n",
818
+ " pass\n",
819
+ " else: \n",
820
+ " # assert that the question text is the same\n",
821
+ " assert result['question_text'] + \"?\" == lines.loc[lines_index, 'question']\n",
822
+ " true_start = result['annotations'][0]['short_answers'][0]['start_byte']\n",
823
+ " true_end = result['annotations'][0]['short_answers'][0]['end_byte']\n",
824
+ " true_answer = bytes.decode(bytes(result['document_html'], encoding='utf-8')[true_start:true_end])\n",
825
+ " \n",
826
+ " processed_answer = lines.loc[lines_index, 'answer']\n",
827
+ " # assert that the answer is the same\n",
828
+ " assert cleanhtml(true_answer) == processed_answer\n",
829
+ " \n",
830
+ " start_ind = int(lines.loc[lines_index, 'answer_start'])\n",
831
+ " # assert that the answer (according to the index) is the same\n",
832
+ " assert cleanhtml(true_answer) == lines.loc[lines_index, 'context'][start_ind:start_ind+len(processed_answer)]\n",
833
+ " \n",
834
+ " lines_index += 1\n",
835
+ " \n",
836
+ " if lines_index == len(lines):\n",
837
+ " break\n",
838
+ "print(\"Passed\")"
839
+ ]
840
+ },
841
+ {
842
+ "cell_type": "markdown",
843
+ "id": "78e6e737",
844
+ "metadata": {},
845
+ "source": [
846
+ "## Hotpot QA"
847
+ ]
848
+ },
849
+ {
850
+ "cell_type": "code",
851
+ "execution_count": null,
852
+ "id": "27efcc8c",
853
+ "metadata": {},
854
+ "outputs": [
855
+ {
856
+ "ename": "",
857
+ "evalue": "",
858
+ "output_type": "error",
859
+ "traceback": [
860
+ "\u001b[1;31mnotebook controller is DISPOSED. \n",
861
+ "\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
862
+ ]
863
+ },
864
+ {
865
+ "ename": "",
866
+ "evalue": "",
867
+ "output_type": "error",
868
+ "traceback": [
869
+ "\u001b[1;31mnotebook controller is DISPOSED. \n",
870
+ "\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
871
+ ]
872
+ }
873
+ ],
874
+ "source": [
875
+ "ds = load_dataset(\"hotpot_qa\", 'fullwiki')"
876
+ ]
877
+ },
878
+ {
879
+ "cell_type": "code",
880
+ "execution_count": null,
881
+ "id": "1493f21f",
882
+ "metadata": {},
883
+ "outputs": [
884
+ {
885
+ "ename": "",
886
+ "evalue": "",
887
+ "output_type": "error",
888
+ "traceback": [
889
+ "\u001b[1;31mnotebook controller is DISPOSED. \n",
890
+ "\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
891
+ ]
892
+ },
893
+ {
894
+ "ename": "",
895
+ "evalue": "",
896
+ "output_type": "error",
897
+ "traceback": [
898
+ "\u001b[1;31mnotebook controller is DISPOSED. \n",
899
+ "\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
900
+ ]
901
+ }
902
+ ],
903
+ "source": [
904
+ "ds"
905
+ ]
906
+ },
907
+ {
908
+ "cell_type": "code",
909
+ "execution_count": null,
910
+ "id": "2a047946",
911
+ "metadata": {},
912
+ "outputs": [
913
+ {
914
+ "ename": "",
915
+ "evalue": "",
916
+ "output_type": "error",
917
+ "traceback": [
918
+ "\u001b[1;31mnotebook controller is DISPOSED. \n",
919
+ "\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
920
+ ]
921
+ },
922
+ {
923
+ "ename": "",
924
+ "evalue": "",
925
+ "output_type": "error",
926
+ "traceback": [
927
+ "\u001b[1;31mnotebook controller is DISPOSED. \n",
928
+ "\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
929
+ ]
930
+ }
931
+ ],
932
+ "source": [
933
+ "os.mkdir('data/hotpotqa_training')\n",
934
+ "os.mkdir('data/hotpotqa_test')"
935
+ ]
936
+ },
937
+ {
938
+ "cell_type": "code",
939
+ "execution_count": null,
940
+ "id": "e65b6485",
941
+ "metadata": {},
942
+ "outputs": [
943
+ {
944
+ "ename": "",
945
+ "evalue": "",
946
+ "output_type": "error",
947
+ "traceback": [
948
+ "\u001b[1;31mnotebook controller is DISPOSED. \n",
949
+ "\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
950
+ ]
951
+ },
952
+ {
953
+ "ename": "",
954
+ "evalue": "",
955
+ "output_type": "error",
956
+ "traceback": [
957
+ "\u001b[1;31mnotebook controller is DISPOSED. \n",
958
+ "\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
959
+ ]
960
+ }
961
+ ],
962
+ "source": [
963
+ "# column contains the split (either train or validation), save_dir is the directory\n",
964
+ "def save_samples(column, save_dir):\n",
965
+ " text = []\n",
966
+ " i = 0\n",
967
+ "\n",
968
+ " for sample in tqdm(ds[column]):\n",
969
+ " \n",
970
+ " # preprocess the context and question by removing the newlines\n",
971
+ " context = sample['context']['sentences']\n",
972
+ " context = \" \".join([\"\".join(sentence) for sentence in context])\n",
973
+ " question = sample['question'].replace('\\n','')\n",
974
+ " \n",
975
+ " # get the answer as text and start character index\n",
976
+ " answer_text = sample['answer']\n",
977
+ " answer_start = context.find(answer_text)\n",
978
+ " if answer_start == -1:\n",
979
+ " continue\n",
980
+ " \n",
981
+ " \n",
982
+ " \n",
983
+ " if answer_start > 1500:\n",
984
+ " first = random.randint(answer_start-1500, answer_start)\n",
985
+ " end = first + 1500 + len(answer_text)\n",
986
+ " \n",
987
+ " context = context[first:end+1]\n",
988
+ " answer_start = context.find(answer_text)\n",
989
+ " \n",
990
+ " if answer_start == -1:continue\n",
991
+ " \n",
992
+ " text.append([context, question, answer_text, str(answer_start)])\n",
993
+ "\n",
994
+ " # we choose chunks of 1000\n",
995
+ " if len(text) == 1000:\n",
996
+ " with open(f\"data/{save_dir}/text_{i}.txt\", 'w', encoding='utf-8') as f:\n",
997
+ " f.write(\"\\n\".join([\"\\t\".join(t) for t in text]))\n",
998
+ " text = []\n",
999
+ " i += 1\n",
1000
+ "\n",
1001
+ " # save remaining\n",
1002
+ " with open(f\"data/{save_dir}/text_{i}.txt\", 'w', encoding='utf-8') as f:\n",
1003
+ " f.write(\"\\n\".join([\"\\t\".join(t) for t in text]))\n",
1004
+ "\n",
1005
+ "save_samples(\"train\", \"hotpotqa_training\")\n",
1006
+ "save_samples(\"validation\", \"hotpotqa_test\")"
1007
+ ]
1008
+ },
1009
+ {
1010
+ "cell_type": "markdown",
1011
+ "id": "97cc358f",
1012
+ "metadata": {},
1013
+ "source": [
1014
+ "## Testing"
1015
+ ]
1016
+ },
1017
+ {
1018
+ "cell_type": "code",
1019
+ "execution_count": null,
1020
+ "id": "f321483c",
1021
+ "metadata": {},
1022
+ "outputs": [
1023
+ {
1024
+ "ename": "",
1025
+ "evalue": "",
1026
+ "output_type": "error",
1027
+ "traceback": [
1028
+ "\u001b[1;31mnotebook controller is DISPOSED. \n",
1029
+ "\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
1030
+ ]
1031
+ },
1032
+ {
1033
+ "ename": "",
1034
+ "evalue": "",
1035
+ "output_type": "error",
1036
+ "traceback": [
1037
+ "\u001b[1;31mnotebook controller is DISPOSED. \n",
1038
+ "\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
1039
+ ]
1040
+ }
1041
+ ],
1042
+ "source": [
1043
+ "with open(\"data/hotpotqa_training/text_0.txt\", 'r', encoding='utf-8') as f:\n",
1044
+ " lines = f.read().split('\\n')\n",
1045
+ " \n",
1046
+ "lines = pd.DataFrame([line.split(\"\\t\") for line in lines], columns=[\"context\", \"question\", \"answer\", \"answer_start\"])"
1047
+ ]
1048
+ },
1049
+ {
1050
+ "cell_type": "code",
1051
+ "execution_count": null,
1052
+ "id": "72a96e78",
1053
+ "metadata": {},
1054
+ "outputs": [
1055
+ {
1056
+ "ename": "",
1057
+ "evalue": "",
1058
+ "output_type": "error",
1059
+ "traceback": [
1060
+ "\u001b[1;31mnotebook controller is DISPOSED. \n",
1061
+ "\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
1062
+ ]
1063
+ },
1064
+ {
1065
+ "ename": "",
1066
+ "evalue": "",
1067
+ "output_type": "error",
1068
+ "traceback": [
1069
+ "\u001b[1;31mnotebook controller is DISPOSED. \n",
1070
+ "\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
1071
+ ]
1072
+ }
1073
+ ],
1074
+ "source": [
1075
+ "assert lines.shape == (1000, 4)\n",
1076
+ "print(\"Passed\")"
1077
+ ]
1078
+ },
1079
+ {
1080
+ "cell_type": "code",
1081
+ "execution_count": null,
1082
+ "id": "c32c2f16",
1083
+ "metadata": {},
1084
+ "outputs": [
1085
+ {
1086
+ "ename": "",
1087
+ "evalue": "",
1088
+ "output_type": "error",
1089
+ "traceback": [
1090
+ "\u001b[1;31mnotebook controller is DISPOSED. \n",
1091
+ "\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
1092
+ ]
1093
+ },
1094
+ {
1095
+ "ename": "",
1096
+ "evalue": "",
1097
+ "output_type": "error",
1098
+ "traceback": [
1099
+ "\u001b[1;31mnotebook controller is DISPOSED. \n",
1100
+ "\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
1101
+ ]
1102
+ }
1103
+ ],
1104
+ "source": [
1105
+ "# we assert that we have the right interval\n",
1106
+ "for ind, line in lines.iterrows():\n",
1107
+ " sample = line\n",
1108
+ " answer_start = int(sample['answer_start'])\n",
1109
+ " assert sample['context'][answer_start:answer_start+len(sample['answer'])] == sample['answer']\n",
1110
+ "print(\"Passed\")"
1111
+ ]
1112
+ },
1113
+ {
1114
+ "cell_type": "code",
1115
+ "execution_count": null,
1116
+ "id": "bc36fe7d",
1117
+ "metadata": {},
1118
+ "outputs": [
1119
+ {
1120
+ "ename": "",
1121
+ "evalue": "",
1122
+ "output_type": "error",
1123
+ "traceback": [
1124
+ "\u001b[1;31mnotebook controller is DISPOSED. \n",
1125
+ "\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
1126
+ ]
1127
+ },
1128
+ {
1129
+ "ename": "",
1130
+ "evalue": "",
1131
+ "output_type": "error",
1132
+ "traceback": [
1133
+ "\u001b[1;31mnotebook controller is DISPOSED. \n",
1134
+ "\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
1135
+ ]
1136
+ }
1137
+ ],
1138
+ "source": []
1139
+ }
1140
+ ],
1141
+ "metadata": {
1142
+ "kernelspec": {
1143
+ "display_name": "Python 3 (ipykernel)",
1144
+ "language": "python",
1145
+ "name": "python3"
1146
+ },
1147
+ "language_info": {
1148
+ "codemirror_mode": {
1149
+ "name": "ipython",
1150
+ "version": 3
1151
+ },
1152
+ "file_extension": ".py",
1153
+ "mimetype": "text/x-python",
1154
+ "name": "python",
1155
+ "nbconvert_exporter": "python",
1156
+ "pygments_lexer": "ipython3",
1157
+ "version": "3.10.16"
1158
+ },
1159
+ "toc": {
1160
+ "base_numbering": 1,
1161
+ "nav_menu": {},
1162
+ "number_sections": true,
1163
+ "sideBar": true,
1164
+ "skip_h1_title": false,
1165
+ "title_cell": "Table of Contents",
1166
+ "title_sidebar": "Contents",
1167
+ "toc_cell": false,
1168
+ "toc_position": {},
1169
+ "toc_section_display": true,
1170
+ "toc_window_display": false
1171
+ },
1172
+ "varInspector": {
1173
+ "cols": {
1174
+ "lenName": 16,
1175
+ "lenType": 16,
1176
+ "lenVar": 40
1177
+ },
1178
+ "kernels_config": {
1179
+ "python": {
1180
+ "delete_cmd_postfix": "",
1181
+ "delete_cmd_prefix": "del ",
1182
+ "library": "var_list.py",
1183
+ "varRefreshCmd": "print(var_dic_list())"
1184
+ },
1185
+ "r": {
1186
+ "delete_cmd_postfix": ") ",
1187
+ "delete_cmd_prefix": "rm(",
1188
+ "library": "var_list.r",
1189
+ "varRefreshCmd": "cat(var_dic_list()) "
1190
+ }
1191
+ },
1192
+ "types_to_exclude": [
1193
+ "module",
1194
+ "function",
1195
+ "builtin_function_or_method",
1196
+ "instance",
1197
+ "_Feature"
1198
+ ],
1199
+ "window_display": false
1200
+ },
1201
+ "vscode": {
1202
+ "interpreter": {
1203
+ "hash": "85bf9c14e9ba73b783ed1274d522bec79eb0b2b739090180d8ce17bb11aff4aa"
1204
+ }
1205
+ }
1206
+ },
1207
+ "nbformat": 4,
1208
+ "nbformat_minor": 5
1209
+ }
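For reference, a minimal standalone sketch of the random-window selection used in the Natural Questions preprocessing above (not part of the commit; the function name and the example byte offsets are illustrative). It makes explicit why a window chosen this way always still contains the answer span.

import random

def choose_window(true_start: int, true_end: int, back: int = 3500) -> tuple[int, int]:
    # never step back further than the beginning of the document
    max_back = back if true_start >= back else true_start
    # pick a random window start at or before the answer start
    first = random.randint(true_start - max_back, true_start)
    # extend the window by the answer length so the full span always fits
    end = first + back + (true_end - true_start)
    return first, end

first, end = choose_window(true_start=10_000, true_end=10_020)
assert first <= 10_000 and 10_020 <= end  # the answer bytes are always inside [first, end)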
qa_model.py ADDED
@@ -0,0 +1,532 @@
1
+ from torch import nn
2
+ import torch
3
+ from typing import Optional
4
+ import copy
5
+ import pandas as pd
6
+
7
+ """
8
+ This module contains the implementation of the QA model. We define three different models and a dataset class.
9
+ The structure is based on the Hugging Face implementations.
10
+ https://huggingface.co/docs/transformers/model_doc/distilbert
11
+ """
12
+
13
+ class SimpleQuestionDistilBERT(nn.Module):
14
+ """
15
+ This class implements a simple version of the distilbert question answering model, following the implementation of Hugging Face,
16
+ https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/distilbert/modeling_distilbert.py#L805
17
+
18
+ It basically fine-tunes a given distilbert model. We only add one linear layer on top, which determines the start and end logits.
19
+ """
20
+ def __init__(self, distilbert, dropout=0.1):
21
+ """
22
+ Creates and initialises model
23
+ """
24
+ super(SimpleQuestionDistilBERT, self).__init__()
25
+
26
+ self.distilbert = distilbert
27
+
28
+ self.dropout = nn.Dropout(dropout)
29
+
30
+ self.classifier = nn.Linear(768, 2)
31
+
32
+ # initialise weights
33
+ def init_weights(m):
34
+ if isinstance(m, nn.Linear):
35
+ nn.init.xavier_uniform_(m.weight)
36
+ m.bias.data.fill_(0.01)
37
+ self.classifier.apply(init_weights)
38
+
39
+
40
+ def forward(self,
41
+ input_ids: Optional[torch.Tensor] = None,
42
+ attention_mask: Optional[torch.Tensor] = None,
43
+ head_mask: Optional[torch.Tensor] = None,
44
+ inputs_embeds: Optional[torch.Tensor] = None,
45
+ start_positions: Optional[torch.Tensor] = None,
46
+ end_positions: Optional[torch.Tensor] = None,
47
+ output_attentions: Optional[bool] = None,
48
+ output_hidden_states: Optional[bool] = None,
49
+ return_dict: Optional[bool] = None):
50
+ """
51
+ This function implements the forward pass of the model. It takes the input_ids and attention_mask and returns the start and end logits.
52
+ """
53
+ # make predictions on base model
54
+ distilbert_output = self.distilbert(
55
+ input_ids=input_ids,
56
+ attention_mask=attention_mask,
57
+ inputs_embeds=inputs_embeds,
58
+ output_attentions=output_attentions,
59
+ output_hidden_states=output_hidden_states,
60
+ return_dict=return_dict,
61
+ )
62
+
63
+ # retrieve hidden states
64
+ hidden_states = distilbert_output[0] # (bs, max_query_len, dim)
65
+ hidden_states = self.dropout(hidden_states)
66
+
67
+ # make predictions on head
68
+ logits = self.classifier(hidden_states)
69
+ start_logits, end_logits = logits.split(1, dim=-1)
70
+ start_logits = start_logits.squeeze(-1).contiguous() # (bs, max_query_len)
71
+ end_logits = end_logits.squeeze(-1).contiguous() # (bs, max_query_len)
72
+
73
+ # calculate loss
74
+ total_loss = None
75
+ if start_positions is not None and end_positions is not None:
76
+ if len(start_positions.size()) > 1:
77
+ start_positions = start_positions.squeeze(-1)
78
+ if len(end_positions.size()) > 1:
79
+ end_positions = end_positions.squeeze(-1)
80
+
81
+ # sometimes the start/end positions are outside our model inputs, we ignore these terms
82
+ ignored_index = start_logits.size(1)
83
+ start_positions = start_positions.clamp(0, ignored_index)
84
+ end_positions = end_positions.clamp(0, ignored_index)
85
+
86
+ loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index)
87
+ start_loss = loss_fct(start_logits, start_positions)
88
+ end_loss = loss_fct(end_logits, end_positions)
89
+ total_loss = (start_loss + end_loss) / 2
90
+
91
+ return {"loss": total_loss,
92
+ "start_logits": start_logits,
93
+ "end_logits": end_logits,
94
+ "hidden_states": distilbert_output.hidden_states,
95
+ "attentions": distilbert_output.attentions}
96
+
97
+
98
+ class QuestionDistilBERT(nn.Module):
99
+ """
100
+ This class implements the distilbert question answering model. We fix all layers of the base model and only fine-tune the head.
101
+ The head consists of a transformer encoder with three layers and a classifier on top.
102
+ """
103
+ def __init__(self, distilbert, dropout=0.1):
104
+ """
105
+         Creates and initialises a QuestionDistilBERT instance
106
+ """
107
+ super(QuestionDistilBERT, self).__init__()
108
+
109
+ # fix parameters for base model
110
+ for param in distilbert.parameters():
111
+ param.requires_grad = False
112
+
113
+ self.distilbert = distilbert
114
+ self.relu = nn.ReLU()
115
+
116
+ self.dropout = nn.Dropout(dropout)
117
+ self.te = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model=768, nhead=12), num_layers=3)
118
+
119
+ # create custom head
120
+ self.classifier = nn.Sequential(
121
+ nn.Dropout(dropout),
122
+ nn.ReLU(),
123
+ nn.Linear(768, 512),
124
+ nn.Dropout(dropout),
125
+ nn.ReLU(),
126
+ nn.Linear(512, 256),
127
+ nn.Dropout(dropout),
128
+ nn.ReLU(),
129
+ nn.Linear(256, 128),
130
+ nn.Dropout(dropout),
131
+ nn.ReLU(),
132
+ nn.Linear(128, 64),
133
+ nn.Dropout(dropout),
134
+ nn.ReLU(),
135
+ nn.Linear(64, 2)
136
+ )
137
+
138
+ # initialise weights of the linear layers
139
+ def init_weights(m):
140
+ if isinstance(m, nn.Linear):
141
+ nn.init.xavier_uniform_(m.weight)
142
+ m.bias.data.fill_(0.01)
143
+
144
+ self.classifier.apply(init_weights)
145
+
146
+ def forward(self,
147
+ input_ids: Optional[torch.Tensor] = None,
148
+ attention_mask: Optional[torch.Tensor] = None,
149
+ head_mask: Optional[torch.Tensor] = None,
150
+ inputs_embeds: Optional[torch.Tensor] = None,
151
+ start_positions: Optional[torch.Tensor] = None,
152
+ end_positions: Optional[torch.Tensor] = None,
153
+ output_attentions: Optional[bool] = None,
154
+ output_hidden_states: Optional[bool] = None,
155
+ return_dict: Optional[bool] = None):
156
+ """
157
+ This function implements the forward pass of the model. It takes the input_ids and attention_mask and returns the start and end logits.
158
+ """
159
+ # make predictions on base model
160
+ distilbert_output = self.distilbert(
161
+ input_ids=input_ids,
162
+ attention_mask=attention_mask,
163
+ inputs_embeds=inputs_embeds,
164
+ output_attentions=output_attentions,
165
+ output_hidden_states=output_hidden_states,
166
+ return_dict=return_dict,
167
+ )
168
+
169
+ # retrieve hidden states
170
+ hidden_states = distilbert_output[0] # (bs, max_query_len, dim)
171
+ hidden_states = self.dropout(hidden_states)
172
+ attn_output = self.te(hidden_states)
173
+
174
+ # make predictions on head
175
+ logits = self.classifier(attn_output)
176
+ start_logits, end_logits = logits.split(1, dim=-1)
177
+ start_logits = start_logits.squeeze(-1).contiguous()
178
+ end_logits = end_logits.squeeze(-1).contiguous()
179
+
180
+ # calculate loss
181
+ total_loss = None
182
+ if start_positions is not None and end_positions is not None:
183
+ if len(start_positions.size()) > 1:
184
+ start_positions = start_positions.squeeze(-1)
185
+ if len(end_positions.size()) > 1:
186
+ end_positions = end_positions.squeeze(-1)
187
+
188
+ # sometimes the start/end positions are outside our model inputs, we ignore these terms
189
+ ignored_index = start_logits.size(1)
190
+ start_positions = start_positions.clamp(0, ignored_index)
191
+ end_positions = end_positions.clamp(0, ignored_index)
192
+
193
+ loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index)
194
+ start_loss = loss_fct(start_logits, start_positions)
195
+ end_loss = loss_fct(end_logits, end_positions)
196
+ total_loss = (start_loss + end_loss) / 2
197
+
198
+ return {"loss": total_loss,
199
+ "start_logits": start_logits,
200
+ "end_logits": end_logits,
201
+ "hidden_states": distilbert_output.hidden_states,
202
+ "attentions": distilbert_output.attentions}
203
+
204
+
205
+ class ReuseQuestionDistilBERT(nn.Module):
206
+ """
207
+     This class implements a model where all layers of the base distilbert model are fixed.
208
+ Instead of training a completely new head, we copy the last two layers of the base model and add a classifier on top.
209
+ """
210
+ def __init__(self, distilbert, dropout=0.15):
211
+ """
212
+         Creates and initialises a ReuseQuestionDistilBERT instance
213
+ """
214
+ super(ReuseQuestionDistilBERT, self).__init__()
215
+ self.te = copy.deepcopy(list(list(distilbert.children())[1].children())[0][-2:])
216
+ # fix parameters for base model
217
+ for param in distilbert.parameters():
218
+ param.requires_grad = False
219
+
220
+ self.distilbert = distilbert
221
+ self.relu = nn.ReLU()
222
+
223
+ self.dropout = nn.Dropout(dropout)
224
+
225
+ # create custom head
226
+ self.classifier = nn.Linear(768, 2)
227
+
228
+ def init_weights(m):
229
+ if isinstance(m, nn.Linear):
230
+ nn.init.xavier_uniform_(m.weight)
231
+ m.bias.data.fill_(0.01)
232
+ self.classifier.apply(init_weights)
233
+
234
+ def forward(self,
235
+ input_ids: Optional[torch.Tensor] = None,
236
+ attention_mask: Optional[torch.Tensor] = None,
237
+ head_mask: Optional[torch.Tensor] = None,
238
+ inputs_embeds: Optional[torch.Tensor] = None,
239
+ start_positions: Optional[torch.Tensor] = None,
240
+ end_positions: Optional[torch.Tensor] = None,
241
+ output_attentions: Optional[bool] = None,
242
+ output_hidden_states: Optional[bool] = None,
243
+ return_dict: Optional[bool] = None):
244
+ """
245
+ This function implements the forward pass of the model. It takes the input_ids and attention_mask and returns the start and end logits.
246
+ """
247
+ # make predictions on base model
248
+ distilbert_output = self.distilbert(
249
+ input_ids=input_ids,
250
+ attention_mask=attention_mask,
251
+ head_mask=head_mask,
252
+ inputs_embeds=inputs_embeds,
253
+ output_attentions=output_attentions,
254
+ output_hidden_states=output_hidden_states,
255
+ return_dict=return_dict,
256
+ )
257
+
258
+ # retrieve hidden states
259
+ hidden_states = distilbert_output[0] # (bs, max_query_len, dim)
260
+ hidden_states = self.dropout(hidden_states)
261
+ for te in self.te:
262
+ hidden_states = te(
263
+ x=hidden_states,
264
+ attn_mask=attention_mask,
265
+ head_mask=head_mask,
266
+ output_attentions=output_attentions
267
+ )[0]
268
+ hidden_states = self.dropout(hidden_states)
269
+
270
+ # make predictions on head
271
+ logits = self.classifier(hidden_states)
272
+ start_logits, end_logits = logits.split(1, dim=-1)
273
+ start_logits = start_logits.squeeze(-1).contiguous() # (bs, max_query_len)
274
+ end_logits = end_logits.squeeze(-1).contiguous() # (bs, max_query_len)
275
+
276
+ # calculate loss
277
+ total_loss = None
278
+ if start_positions is not None and end_positions is not None:
279
+ if len(start_positions.size()) > 1:
280
+ start_positions = start_positions.squeeze(-1)
281
+ if len(end_positions.size()) > 1:
282
+ end_positions = end_positions.squeeze(-1)
283
+
284
+ # sometimes the start/end positions are outside our model inputs, we ignore these terms
285
+ ignored_index = start_logits.size(1)
286
+ start_positions = start_positions.clamp(0, ignored_index)
287
+ end_positions = end_positions.clamp(0, ignored_index)
288
+
289
+ loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index)
290
+ start_loss = loss_fct(start_logits, start_positions)
291
+ end_loss = loss_fct(end_logits, end_positions)
292
+ total_loss = (start_loss + end_loss) / 2
293
+
294
+ return {"loss": total_loss,
295
+ "start_logits": start_logits,
296
+ "end_logits": end_logits,
297
+ "hidden_states": distilbert_output.hidden_states,
298
+ "attentions": distilbert_output.attentions}
299
+
300
+ class Dataset(torch.utils.data.Dataset):
301
+ """
302
+ This class creates a dataset for the DistilBERT qa-model.
303
+ """
304
+ def __init__(self, squad_paths, natural_question_paths, hotpotqa_paths, tokenizer):
305
+ """
306
+ creates and initialises dataset object
307
+ """
308
+ self.paths = []
309
+ self.count = 0
310
+ if squad_paths != None:
311
+ self.paths.extend(squad_paths[:len(squad_paths)-1])
312
+ if natural_question_paths != None:
313
+ self.paths.extend(natural_question_paths[:len(natural_question_paths)-1])
314
+ if hotpotqa_paths != None:
315
+ self.paths.extend(hotpotqa_paths[:len(hotpotqa_paths)-1])
316
+ self.data = None
317
+ self.current_file = 0
318
+ self.remaining = 0
319
+ self.encodings = None
320
+ # tokenizer for strings
321
+ self.tokenizer = tokenizer
322
+
323
+
324
+ def __len__(self):
325
+ """
326
+ returns the length of the dataset
327
+ """
328
+ return len(self.paths)*1000
329
+
330
+ def read_file(self, path):
331
+ """
332
+ reads the file stored at path
333
+ """
334
+ with open(path, 'r', encoding='utf-8') as f:
335
+ lines = f.read().split('\n')
336
+ return lines
337
+
338
+ def get_encodings(self):
339
+ """
340
+ returns encoded strings for the model
341
+ """
342
+         # strip whitespace from the questions and lowercase the contexts
343
+ questions = [q.strip() for q in self.data["question"]]
344
+ context = [q.lower() for q in self.data["context"]]
345
+
346
+ # tokenises questions and context. If the context is too long, we truncate it.
347
+ inputs = self.tokenizer(
348
+ questions,
349
+ context,
350
+ max_length=512,
351
+ truncation="only_second",
352
+ return_offsets_mapping=True,
353
+ padding="max_length",
354
+ )
355
+
356
+ # tuples of integers giving us the original positions
357
+ offset_mapping = inputs.pop("offset_mapping")
358
+
359
+ answers = self.data["answer"]
360
+ answer_start = self.data["answer_start"]
361
+
362
+ # store beginning and end positions
363
+ start_positions = []
364
+ end_positions = []
365
+
366
+ # iterate through questions
367
+ for i, offset in enumerate(offset_mapping):
368
+
369
+ answer = answers[i]
370
+ start_char = int(answer_start[i])
371
+ end_char = start_char + len(answer)
372
+
373
+ sequence_ids = inputs.sequence_ids(i)
374
+
375
+ # start and end of context based on tokens
376
+ idx = 0
377
+ while sequence_ids[idx] != 1:
378
+ idx += 1
379
+
380
+ context_start = idx
381
+ while sequence_ids[idx] == 1:
382
+ idx += 1
383
+ context_end = idx - 1
384
+
385
+             # If the answer is not fully inside the context, label it with (0, 0)
386
+ if offset[context_start][0] > end_char or offset[context_end][1] < start_char:
387
+ start_positions.append(0)
388
+ end_positions.append(0)
389
+ self.count += 1
390
+ else:
391
+ # go to first offset position that is smaller than start char
392
+ idx = context_start
393
+ while idx <= context_end and offset[idx][0] <= start_char:
394
+ idx += 1
395
+
396
+ start_positions.append(idx - 1)
397
+ idx = context_end
398
+ while idx >= context_start and offset[idx][1] >= end_char:
399
+ idx -= 1
400
+ end_positions.append(idx + 1)
401
+
402
+ # append start and end position to the embeddings
403
+ inputs["start_positions"] = start_positions
404
+ inputs["end_positions"] = end_positions
405
+ # return input_ids, attention mask, start and end positions (GT)
406
+ return {'input_ids': torch.tensor(inputs['input_ids']),
407
+ 'attention_mask': torch.tensor(inputs['attention_mask']),
408
+ 'start_positions': torch.tensor(inputs['start_positions']),
409
+ 'end_positions': torch.tensor(inputs['end_positions'])}
410
+
411
+ def __getitem__(self, i):
412
+ """
413
+ returns encoding of item i
414
+ """
415
+
416
+ # if we have looked at all items in the file - take next
417
+ if self.remaining == 0:
418
+ self.data = self.read_file(self.paths[self.current_file])
419
+ self.data = pd.DataFrame([line.split("\t") for line in self.data],
420
+ columns=["context", "question", "answer", "answer_start"])
421
+ self.current_file += 1
422
+ self.remaining = len(self.data)
423
+ self.encodings = self.get_encodings()
424
+ # if we are at the end of the dataset, start over again
425
+ if self.current_file == len(self.paths):
426
+ self.current_file = 0
427
+ self.remaining -= 1
428
+ return {key: tensor[i%1000] for key, tensor in self.encodings.items()}
429
+
430
+ def test_model(model, optim, test_ds_loader, device):
431
+ """
432
+     This function is used to sanity-check the model: parameters must be neither NaN nor infinite,
433
+     non-frozen parameters have to change after one optimisation step, and frozen ones must not.
434
+ :param model: pytorch model to evaluate
435
+ :param optim: optimizer
436
+ :param test_ds_loader: dataloader object
437
+ :param device: device, the model is on
438
+ :raises Exception if the model doesn't work as expected
439
+ """
440
+ ## Check if non-frozen parameters changed and frozen ones did not
441
+
442
+ # get parameters used for tuning and store initial weight
443
+ params = [np for np in model.named_parameters() if np[1].requires_grad]
444
+ initial_params = [(name, p.clone()) for (name, p) in params]
445
+
446
+ # get frozen parameters and store initial weight
447
+ params_frozen = [np for np in model.named_parameters() if not np[1].requires_grad]
448
+ initial_params_frozen = [(name, p.clone()) for (name, p) in params_frozen]
449
+
450
+ # perform one iteration
451
+ optim.zero_grad()
452
+ batch = next(iter(test_ds_loader))
453
+
454
+ input_ids = batch['input_ids'].to(device)
455
+ attention_mask = batch['attention_mask'].to(device)
456
+ start_positions = batch['start_positions'].to(device)
457
+ end_positions = batch['end_positions'].to(device)
458
+
459
+ # forward pass and backpropagation
460
+ outputs = model(input_ids, attention_mask=attention_mask, start_positions=start_positions,
461
+ end_positions=end_positions)
462
+ loss = outputs['loss']
463
+ loss.backward()
464
+ optim.step()
465
+
466
+ # check if variables have changed
467
+ for (_, p0), (name, p1) in zip(initial_params, params):
468
+ # check different than initial
469
+ try:
470
+ assert not torch.equal(p0.to(device), p1.to(device))
471
+ except AssertionError:
472
+ raise Exception(
473
+ "{var_name} {msg}".format(
474
+ var_name=name,
475
+ msg='did not change!'
476
+ )
477
+ )
478
+ # check not NaN
479
+ try:
480
+ assert not torch.isnan(p1).byte().any()
481
+ except AssertionError:
482
+ raise Exception(
483
+ "{var_name} {msg}".format(
484
+ var_name=name,
485
+ msg='is NaN!'
486
+ )
487
+ )
488
+ # check finite
489
+ try:
490
+ assert torch.isfinite(p1).byte().all()
491
+ except AssertionError:
492
+ raise Exception(
493
+ "{var_name} {msg}".format(
494
+ var_name=name,
495
+ msg='is Inf!'
496
+ )
497
+ )
498
+
499
+ # check that frozen weights have not changed
500
+ for (_, p0), (name, p1) in zip(initial_params_frozen, params_frozen):
501
+ # should be the same
502
+ try:
503
+ assert torch.equal(p0.to(device), p1.to(device))
504
+ except AssertionError:
505
+ raise Exception(
506
+ "{var_name} {msg}".format(
507
+ var_name=name,
508
+ msg='changed!'
509
+ )
510
+ )
511
+ # check not NaN
512
+ try:
513
+ assert not torch.isnan(p1).byte().any()
514
+ except AssertionError:
515
+ raise Exception(
516
+ "{var_name} {msg}".format(
517
+ var_name=name,
518
+ msg='is NaN!'
519
+ )
520
+ )
521
+
522
+ # check finite numbers
523
+ try:
524
+ assert torch.isfinite(p1).byte().all()
525
+ except AssertionError:
526
+ raise Exception(
527
+ "{var_name} {msg}".format(
528
+ var_name=name,
529
+ msg='is Inf!'
530
+ )
531
+ )
532
+ print("Passed")
question_answering.ipynb ADDED
@@ -0,0 +1,2403 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "19817716",
6
+ "metadata": {},
7
+ "source": [
8
+ "# Question Answering\n",
9
+ "The following notebook contains different question answering models. We will start by introducing a representation for the dataset and corresponding DataLoader and then evaluate different models."
10
+ ]
11
+ },
12
+ {
13
+ "cell_type": "code",
14
+ "execution_count": 50,
15
+ "id": "49bf46c6",
16
+ "metadata": {},
17
+ "outputs": [],
18
+ "source": [
19
+ "from transformers import DistilBertModel, DistilBertForMaskedLM, DistilBertConfig, \\\n",
20
+ " DistilBertTokenizerFast, AutoTokenizer, BertModel, BertForMaskedLM, BertTokenizerFast, BertConfig\n",
21
+ "from torch import nn\n",
22
+ "from pathlib import Path\n",
23
+ "import torch\n",
24
+ "import pandas as pd\n",
25
+ "from typing import Optional \n",
26
+ "from tqdm.auto import tqdm\n",
27
+ "from util import eval_test_set, count_parameters\n",
28
+ "from torch.optim import AdamW, RMSprop\n",
29
+ "\n",
30
+ "\n",
31
+ "from qa_model import QuestionDistilBERT, SimpleQuestionDistilBERT, ReuseQuestionDistilBERT, Dataset, test_model"
32
+ ]
33
+ },
34
+ {
35
+ "cell_type": "markdown",
36
+ "id": "3ea47820",
37
+ "metadata": {},
38
+ "source": [
39
+ "## Data\n",
40
+ "Processing the data correctly is partly based on the Huggingface Tutorial (https://huggingface.co/course/chapter7/7?fw=pt)"
41
+ ]
42
+ },
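As a rough sketch of what this preprocessing does (illustrative only: the authoritative version is `Dataset.get_encodings` in qa_model.py, and the example strings and character offsets below are made up), the question and context are tokenized together and the character-level answer span is mapped to token-level start/end positions via the offset mapping:

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
    enc = tokenizer(
        "Who wrote Faust?",                      # question
        "Faust was written by Goethe.",          # context
        max_length=512,
        truncation="only_second",
        padding="max_length",
        return_offsets_mapping=True,
    )

    answer_start, answer_text = 21, "Goethe"     # character span inside the context
    answer_end = answer_start + len(answer_text)

    seq_ids = enc.sequence_ids()
    start_pos = end_pos = 0
    for idx, (s, e) in enumerate(enc["offset_mapping"]):
        if seq_ids[idx] != 1:                    # skip question, special and padding tokens
            continue
        if s <= answer_start < e:
            start_pos = idx                      # first context token covering the answer
        if s < answer_end <= e:
            end_pos = idx                        # last context token covering the answer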
43
+ {
44
+ "cell_type": "code",
45
+ "execution_count": 51,
46
+ "id": "7b1b2b3e",
47
+ "metadata": {},
48
+ "outputs": [],
49
+ "source": [
50
+ "tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')"
51
+ ]
52
+ },
53
+ {
54
+ "cell_type": "code",
55
+ "execution_count": 52,
56
+ "id": "f276eba7",
57
+ "metadata": {
58
+ "scrolled": false
59
+ },
60
+ "outputs": [],
61
+ "source": [
62
+ " \n",
63
+ "# create datasets and loaders for training and test set\n",
64
+ "squad_paths = [str(x) for x in Path('data/training_squad/').glob('**/*.txt')]\n",
65
+ "nat_paths = [str(x) for x in Path('data/natural_questions_train/').glob('**/*.txt')]\n",
66
+ "hotpotqa_paths = [str(x) for x in Path('data/hotpotqa_training/').glob('**/*.txt')]"
67
+ ]
68
+ },
69
+ {
70
+ "cell_type": "markdown",
71
+ "id": "ad8d532a",
72
+ "metadata": {},
73
+ "source": [
74
+ "## POC Model\n",
75
+ "* Works very well:\n",
76
+ " * Dropout 0.1 is too small (overfitting after first epoch) - changed to 0.15\n",
77
+ " * Difference between AdamW and RMSprop minimal\n",
78
+ " \n",
79
+ "### Results:\n",
80
+ "Dropout = 0.15\n",
81
+ "* Mean EM: 0.5374\n",
82
+ "* Mean F-1: 0.6826317532406944\n",
83
+ "\n",
84
+ "Dropout = 0.2 (overfitting realtively similar to first, but seems to be too high)\n",
85
+ "* Mean EM: 0.5044\n",
86
+ "* Mean F-1: 0.6437359169276439"
87
+ ]
88
+ },
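The "Mean EM" and "Mean F-1" numbers above are produced by `eval_test_set` from util.py; as a reference, a minimal sketch of the standard SQuAD-style definitions (exact string match after lower-casing, and token-overlap F1; the repository's implementation may normalise answers differently) looks like this:

    from collections import Counter

    def exact_match(prediction: str, truth: str) -> float:
        return float(prediction.strip().lower() == truth.strip().lower())

    def f1_score(prediction: str, truth: str) -> float:
        pred_tokens = prediction.lower().split()
        truth_tokens = truth.lower().split()
        common = Counter(pred_tokens) & Counter(truth_tokens)
        overlap = sum(common.values())
        if overlap == 0:
            return 0.0
        precision = overlap / len(pred_tokens)
        recall = overlap / len(truth_tokens)
        return 2 * precision * recall / (precision + recall)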
89
+ {
90
+ "cell_type": "code",
91
+ "execution_count": 54,
92
+ "id": "703e7f38",
93
+ "metadata": {},
94
+ "outputs": [],
95
+ "source": [
96
+ "dataset = Dataset(squad_paths = squad_paths, natural_question_paths=None, hotpotqa_paths=hotpotqa_paths, tokenizer=tokenizer)\n",
97
+ "loader = torch.utils.data.DataLoader(dataset, batch_size=8)\n",
98
+ "\n",
99
+ "test_dataset = Dataset(squad_paths = [str(x) for x in Path('data/test_squad/').glob('**/*.txt')], \n",
100
+ " natural_question_paths=None, \n",
101
+ " hotpotqa_paths = None, tokenizer=tokenizer)\n",
102
+ "test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=4)"
103
+ ]
104
+ },
105
+ {
106
+ "cell_type": "code",
107
+ "execution_count": 55,
108
+ "id": "6672f614",
109
+ "metadata": {},
110
+ "outputs": [],
111
+ "source": [
112
+ "model = DistilBertForMaskedLM.from_pretrained(\"distilbert-base-uncased\")\n",
113
+ "config = DistilBertConfig.from_pretrained(\"distilbert-base-uncased\")\n",
114
+ "mod = model.distilbert"
115
+ ]
116
+ },
117
+ {
118
+ "cell_type": "code",
119
+ "execution_count": 56,
120
+ "id": "dec15198",
121
+ "metadata": {},
122
+ "outputs": [
123
+ {
124
+ "data": {
125
+ "text/plain": [
126
+ "SimpleQuestionDistilBERT(\n",
127
+ " (distilbert): DistilBertModel(\n",
128
+ " (embeddings): Embeddings(\n",
129
+ " (word_embeddings): Embedding(30522, 768, padding_idx=0)\n",
130
+ " (position_embeddings): Embedding(512, 768)\n",
131
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
132
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
133
+ " )\n",
134
+ " (transformer): Transformer(\n",
135
+ " (layer): ModuleList(\n",
136
+ " (0): TransformerBlock(\n",
137
+ " (attention): MultiHeadSelfAttention(\n",
138
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
139
+ " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
140
+ " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
141
+ " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
142
+ " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
143
+ " )\n",
144
+ " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
145
+ " (ffn): FFN(\n",
146
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
147
+ " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
148
+ " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
149
+ " (activation): GELUActivation()\n",
150
+ " )\n",
151
+ " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
152
+ " )\n",
153
+ " (1): TransformerBlock(\n",
154
+ " (attention): MultiHeadSelfAttention(\n",
155
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
156
+ " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
157
+ " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
158
+ " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
159
+ " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
160
+ " )\n",
161
+ " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
162
+ " (ffn): FFN(\n",
163
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
164
+ " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
165
+ " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
166
+ " (activation): GELUActivation()\n",
167
+ " )\n",
168
+ " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
169
+ " )\n",
170
+ " (2): TransformerBlock(\n",
171
+ " (attention): MultiHeadSelfAttention(\n",
172
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
173
+ " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
174
+ " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
175
+ " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
176
+ " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
177
+ " )\n",
178
+ " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
179
+ " (ffn): FFN(\n",
180
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
181
+ " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
182
+ " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
183
+ " (activation): GELUActivation()\n",
184
+ " )\n",
185
+ " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
186
+ " )\n",
187
+ " (3): TransformerBlock(\n",
188
+ " (attention): MultiHeadSelfAttention(\n",
189
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
190
+ " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
191
+ " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
192
+ " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
193
+ " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
194
+ " )\n",
195
+ " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
196
+ " (ffn): FFN(\n",
197
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
198
+ " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
199
+ " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
200
+ " (activation): GELUActivation()\n",
201
+ " )\n",
202
+ " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
203
+ " )\n",
204
+ " (4): TransformerBlock(\n",
205
+ " (attention): MultiHeadSelfAttention(\n",
206
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
207
+ " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
208
+ " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
209
+ " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
210
+ " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
211
+ " )\n",
212
+ " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
213
+ " (ffn): FFN(\n",
214
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
215
+ " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
216
+ " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
217
+ " (activation): GELUActivation()\n",
218
+ " )\n",
219
+ " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
220
+ " )\n",
221
+ " (5): TransformerBlock(\n",
222
+ " (attention): MultiHeadSelfAttention(\n",
223
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
224
+ " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
225
+ " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
226
+ " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
227
+ " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
228
+ " )\n",
229
+ " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
230
+ " (ffn): FFN(\n",
231
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
232
+ " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
233
+ " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
234
+ " (activation): GELUActivation()\n",
235
+ " )\n",
236
+ " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
237
+ " )\n",
238
+ " )\n",
239
+ " )\n",
240
+ " )\n",
241
+ " (dropout): Dropout(p=0.5, inplace=False)\n",
242
+ " (classifier): Linear(in_features=768, out_features=2, bias=True)\n",
243
+ ")"
244
+ ]
245
+ },
246
+ "execution_count": 56,
247
+ "metadata": {},
248
+ "output_type": "execute_result"
249
+ }
250
+ ],
251
+ "source": [
252
+ "device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')\n",
253
+ "model = SimpleQuestionDistilBERT(mod)\n",
254
+ "model.to(device)"
255
+ ]
256
+ },
257
+ {
258
+ "cell_type": "code",
259
+ "execution_count": 57,
260
+ "id": "9def3c83",
261
+ "metadata": {},
262
+ "outputs": [
263
+ {
264
+ "name": "stdout",
265
+ "output_type": "stream",
266
+ "text": [
267
+ "+---------------------------------------------------------+------------+\n",
268
+ "| Modules | Parameters |\n",
269
+ "+---------------------------------------------------------+------------+\n",
270
+ "| distilbert.embeddings.word_embeddings.weight | 23440896 |\n",
271
+ "| distilbert.embeddings.position_embeddings.weight | 393216 |\n",
272
+ "| distilbert.embeddings.LayerNorm.weight | 768 |\n",
273
+ "| distilbert.embeddings.LayerNorm.bias | 768 |\n",
274
+ "| distilbert.transformer.layer.0.attention.q_lin.weight | 589824 |\n",
275
+ "| distilbert.transformer.layer.0.attention.q_lin.bias | 768 |\n",
276
+ "| distilbert.transformer.layer.0.attention.k_lin.weight | 589824 |\n",
277
+ "| distilbert.transformer.layer.0.attention.k_lin.bias | 768 |\n",
278
+ "| distilbert.transformer.layer.0.attention.v_lin.weight | 589824 |\n",
279
+ "| distilbert.transformer.layer.0.attention.v_lin.bias | 768 |\n",
280
+ "| distilbert.transformer.layer.0.attention.out_lin.weight | 589824 |\n",
281
+ "| distilbert.transformer.layer.0.attention.out_lin.bias | 768 |\n",
282
+ "| distilbert.transformer.layer.0.sa_layer_norm.weight | 768 |\n",
283
+ "| distilbert.transformer.layer.0.sa_layer_norm.bias | 768 |\n",
284
+ "| distilbert.transformer.layer.0.ffn.lin1.weight | 2359296 |\n",
285
+ "| distilbert.transformer.layer.0.ffn.lin1.bias | 3072 |\n",
286
+ "| distilbert.transformer.layer.0.ffn.lin2.weight | 2359296 |\n",
287
+ "| distilbert.transformer.layer.0.ffn.lin2.bias | 768 |\n",
288
+ "| distilbert.transformer.layer.0.output_layer_norm.weight | 768 |\n",
289
+ "| distilbert.transformer.layer.0.output_layer_norm.bias | 768 |\n",
290
+ "| distilbert.transformer.layer.1.attention.q_lin.weight | 589824 |\n",
291
+ "| distilbert.transformer.layer.1.attention.q_lin.bias | 768 |\n",
292
+ "| distilbert.transformer.layer.1.attention.k_lin.weight | 589824 |\n",
293
+ "| distilbert.transformer.layer.1.attention.k_lin.bias | 768 |\n",
294
+ "| distilbert.transformer.layer.1.attention.v_lin.weight | 589824 |\n",
295
+ "| distilbert.transformer.layer.1.attention.v_lin.bias | 768 |\n",
296
+ "| distilbert.transformer.layer.1.attention.out_lin.weight | 589824 |\n",
297
+ "| distilbert.transformer.layer.1.attention.out_lin.bias | 768 |\n",
298
+ "| distilbert.transformer.layer.1.sa_layer_norm.weight | 768 |\n",
299
+ "| distilbert.transformer.layer.1.sa_layer_norm.bias | 768 |\n",
300
+ "| distilbert.transformer.layer.1.ffn.lin1.weight | 2359296 |\n",
301
+ "| distilbert.transformer.layer.1.ffn.lin1.bias | 3072 |\n",
302
+ "| distilbert.transformer.layer.1.ffn.lin2.weight | 2359296 |\n",
303
+ "| distilbert.transformer.layer.1.ffn.lin2.bias | 768 |\n",
304
+ "| distilbert.transformer.layer.1.output_layer_norm.weight | 768 |\n",
305
+ "| distilbert.transformer.layer.1.output_layer_norm.bias | 768 |\n",
306
+ "| distilbert.transformer.layer.2.attention.q_lin.weight | 589824 |\n",
307
+ "| distilbert.transformer.layer.2.attention.q_lin.bias | 768 |\n",
308
+ "| distilbert.transformer.layer.2.attention.k_lin.weight | 589824 |\n",
309
+ "| distilbert.transformer.layer.2.attention.k_lin.bias | 768 |\n",
310
+ "| distilbert.transformer.layer.2.attention.v_lin.weight | 589824 |\n",
311
+ "| distilbert.transformer.layer.2.attention.v_lin.bias | 768 |\n",
312
+ "| distilbert.transformer.layer.2.attention.out_lin.weight | 589824 |\n",
313
+ "| distilbert.transformer.layer.2.attention.out_lin.bias | 768 |\n",
314
+ "| distilbert.transformer.layer.2.sa_layer_norm.weight | 768 |\n",
315
+ "| distilbert.transformer.layer.2.sa_layer_norm.bias | 768 |\n",
316
+ "| distilbert.transformer.layer.2.ffn.lin1.weight | 2359296 |\n",
317
+ "| distilbert.transformer.layer.2.ffn.lin1.bias | 3072 |\n",
318
+ "| distilbert.transformer.layer.2.ffn.lin2.weight | 2359296 |\n",
319
+ "| distilbert.transformer.layer.2.ffn.lin2.bias | 768 |\n",
320
+ "| distilbert.transformer.layer.2.output_layer_norm.weight | 768 |\n",
321
+ "| distilbert.transformer.layer.2.output_layer_norm.bias | 768 |\n",
322
+ "| distilbert.transformer.layer.3.attention.q_lin.weight | 589824 |\n",
323
+ "| distilbert.transformer.layer.3.attention.q_lin.bias | 768 |\n",
324
+ "| distilbert.transformer.layer.3.attention.k_lin.weight | 589824 |\n",
325
+ "| distilbert.transformer.layer.3.attention.k_lin.bias | 768 |\n",
326
+ "| distilbert.transformer.layer.3.attention.v_lin.weight | 589824 |\n",
327
+ "| distilbert.transformer.layer.3.attention.v_lin.bias | 768 |\n",
328
+ "| distilbert.transformer.layer.3.attention.out_lin.weight | 589824 |\n",
329
+ "| distilbert.transformer.layer.3.attention.out_lin.bias | 768 |\n",
330
+ "| distilbert.transformer.layer.3.sa_layer_norm.weight | 768 |\n",
331
+ "| distilbert.transformer.layer.3.sa_layer_norm.bias | 768 |\n",
332
+ "| distilbert.transformer.layer.3.ffn.lin1.weight | 2359296 |\n",
333
+ "| distilbert.transformer.layer.3.ffn.lin1.bias | 3072 |\n",
334
+ "| distilbert.transformer.layer.3.ffn.lin2.weight | 2359296 |\n",
335
+ "| distilbert.transformer.layer.3.ffn.lin2.bias | 768 |\n",
336
+ "| distilbert.transformer.layer.3.output_layer_norm.weight | 768 |\n",
337
+ "| distilbert.transformer.layer.3.output_layer_norm.bias | 768 |\n",
338
+ "| distilbert.transformer.layer.4.attention.q_lin.weight | 589824 |\n",
339
+ "| distilbert.transformer.layer.4.attention.q_lin.bias | 768 |\n",
340
+ "| distilbert.transformer.layer.4.attention.k_lin.weight | 589824 |\n",
341
+ "| distilbert.transformer.layer.4.attention.k_lin.bias | 768 |\n",
342
+ "| distilbert.transformer.layer.4.attention.v_lin.weight | 589824 |\n",
343
+ "| distilbert.transformer.layer.4.attention.v_lin.bias | 768 |\n",
344
+ "| distilbert.transformer.layer.4.attention.out_lin.weight | 589824 |\n",
345
+ "| distilbert.transformer.layer.4.attention.out_lin.bias | 768 |\n",
346
+ "| distilbert.transformer.layer.4.sa_layer_norm.weight | 768 |\n",
347
+ "| distilbert.transformer.layer.4.sa_layer_norm.bias | 768 |\n",
348
+ "| distilbert.transformer.layer.4.ffn.lin1.weight | 2359296 |\n",
349
+ "| distilbert.transformer.layer.4.ffn.lin1.bias | 3072 |\n",
350
+ "| distilbert.transformer.layer.4.ffn.lin2.weight | 2359296 |\n",
351
+ "| distilbert.transformer.layer.4.ffn.lin2.bias | 768 |\n",
352
+ "| distilbert.transformer.layer.4.output_layer_norm.weight | 768 |\n",
353
+ "| distilbert.transformer.layer.4.output_layer_norm.bias | 768 |\n",
354
+ "| distilbert.transformer.layer.5.attention.q_lin.weight | 589824 |\n",
355
+ "| distilbert.transformer.layer.5.attention.q_lin.bias | 768 |\n",
356
+ "| distilbert.transformer.layer.5.attention.k_lin.weight | 589824 |\n",
357
+ "| distilbert.transformer.layer.5.attention.k_lin.bias | 768 |\n",
358
+ "| distilbert.transformer.layer.5.attention.v_lin.weight | 589824 |\n",
359
+ "| distilbert.transformer.layer.5.attention.v_lin.bias | 768 |\n",
360
+ "| distilbert.transformer.layer.5.attention.out_lin.weight | 589824 |\n",
361
+ "| distilbert.transformer.layer.5.attention.out_lin.bias | 768 |\n",
362
+ "| distilbert.transformer.layer.5.sa_layer_norm.weight | 768 |\n",
363
+ "| distilbert.transformer.layer.5.sa_layer_norm.bias | 768 |\n",
364
+ "| distilbert.transformer.layer.5.ffn.lin1.weight | 2359296 |\n",
365
+ "| distilbert.transformer.layer.5.ffn.lin1.bias | 3072 |\n",
366
+ "| distilbert.transformer.layer.5.ffn.lin2.weight | 2359296 |\n",
367
+ "| distilbert.transformer.layer.5.ffn.lin2.bias | 768 |\n",
368
+ "| distilbert.transformer.layer.5.output_layer_norm.weight | 768 |\n",
369
+ "| distilbert.transformer.layer.5.output_layer_norm.bias | 768 |\n",
370
+ "| classifier.weight | 1536 |\n",
371
+ "| classifier.bias | 2 |\n",
372
+ "+---------------------------------------------------------+------------+\n",
373
+ "Total Trainable Params: 66364418\n"
374
+ ]
375
+ },
376
+ {
377
+ "data": {
378
+ "text/plain": [
379
+ "66364418"
380
+ ]
381
+ },
382
+ "execution_count": 57,
383
+ "metadata": {},
384
+ "output_type": "execute_result"
385
+ }
386
+ ],
387
+ "source": [
388
+ "count_parameters(model)"
389
+ ]
390
+ },
391
+ {
392
+ "cell_type": "markdown",
393
+ "id": "426a6311",
394
+ "metadata": {},
395
+ "source": [
396
+ "### Testing the model"
397
+ ]
398
+ },
399
+ {
400
+ "cell_type": "code",
401
+ "execution_count": 58,
402
+ "id": "6151c201",
403
+ "metadata": {},
404
+ "outputs": [],
405
+ "source": [
406
+ "# get smaller dataset\n",
407
+ "batch_size = 8\n",
408
+ "test_ds = Dataset(squad_paths = squad_paths[:2], natural_question_paths=None, hotpotqa_paths=None, tokenizer=tokenizer)\n",
409
+ "test_ds_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size)\n",
410
+ "optim = RMSprop(model.parameters(), lr=1e-4)"
411
+ ]
412
+ },
413
+ {
414
+ "cell_type": "code",
415
+ "execution_count": 59,
416
+ "id": "aeae0c56",
417
+ "metadata": {},
418
+ "outputs": [
419
+ {
420
+ "name": "stdout",
421
+ "output_type": "stream",
422
+ "text": [
423
+ "Passed\n"
424
+ ]
425
+ }
426
+ ],
427
+ "source": [
428
+ "test_model(model, optim, test_ds_loader, device)"
429
+ ]
430
+ },
431
+ {
432
+ "cell_type": "markdown",
433
+ "id": "59928d34",
434
+ "metadata": {},
435
+ "source": [
436
+ "### Model Training"
437
+ ]
438
+ },
439
+ {
440
+ "cell_type": "code",
441
+ "execution_count": 60,
442
+ "id": "a8017b8c",
443
+ "metadata": {},
444
+ "outputs": [
445
+ {
446
+ "data": {
447
+ "text/plain": [
448
+ "SimpleQuestionDistilBERT(\n",
449
+ " (distilbert): DistilBertModel(\n",
450
+ " (embeddings): Embeddings(\n",
451
+ " (word_embeddings): Embedding(30522, 768, padding_idx=0)\n",
452
+ " (position_embeddings): Embedding(512, 768)\n",
453
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
454
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
455
+ " )\n",
456
+ " (transformer): Transformer(\n",
457
+ " (layer): ModuleList(\n",
458
+ " (0): TransformerBlock(\n",
459
+ " (attention): MultiHeadSelfAttention(\n",
460
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
461
+ " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
462
+ " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
463
+ " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
464
+ " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
465
+ " )\n",
466
+ " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
467
+ " (ffn): FFN(\n",
468
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
469
+ " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
470
+ " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
471
+ " (activation): GELUActivation()\n",
472
+ " )\n",
473
+ " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
474
+ " )\n",
475
+ " (1): TransformerBlock(\n",
476
+ " (attention): MultiHeadSelfAttention(\n",
477
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
478
+ " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
479
+ " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
480
+ " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
481
+ " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
482
+ " )\n",
483
+ " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
484
+ " (ffn): FFN(\n",
485
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
486
+ " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
487
+ " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
488
+ " (activation): GELUActivation()\n",
489
+ " )\n",
490
+ " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
491
+ " )\n",
492
+ " (2): TransformerBlock(\n",
493
+ " (attention): MultiHeadSelfAttention(\n",
494
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
495
+ " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
496
+ " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
497
+ " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
498
+ " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
499
+ " )\n",
500
+ " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
501
+ " (ffn): FFN(\n",
502
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
503
+ " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
504
+ " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
505
+ " (activation): GELUActivation()\n",
506
+ " )\n",
507
+ " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
508
+ " )\n",
509
+ " (3): TransformerBlock(\n",
510
+ " (attention): MultiHeadSelfAttention(\n",
511
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
512
+ " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
513
+ " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
514
+ " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
515
+ " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
516
+ " )\n",
517
+ " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
518
+ " (ffn): FFN(\n",
519
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
520
+ " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
521
+ " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
522
+ " (activation): GELUActivation()\n",
523
+ " )\n",
524
+ " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
525
+ " )\n",
526
+ " (4): TransformerBlock(\n",
527
+ " (attention): MultiHeadSelfAttention(\n",
528
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
529
+ " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
530
+ " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
531
+ " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
532
+ " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
533
+ " )\n",
534
+ " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
535
+ " (ffn): FFN(\n",
536
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
537
+ " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
538
+ " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
539
+ " (activation): GELUActivation()\n",
540
+ " )\n",
541
+ " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
542
+ " )\n",
543
+ " (5): TransformerBlock(\n",
544
+ " (attention): MultiHeadSelfAttention(\n",
545
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
546
+ " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
547
+ " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
548
+ " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
549
+ " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
550
+ " )\n",
551
+ " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
552
+ " (ffn): FFN(\n",
553
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
554
+ " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
555
+ " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
556
+ " (activation): GELUActivation()\n",
557
+ " )\n",
558
+ " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
559
+ " )\n",
560
+ " )\n",
561
+ " )\n",
562
+ " )\n",
563
+ " (dropout): Dropout(p=0.5, inplace=False)\n",
564
+ " (classifier): Linear(in_features=768, out_features=2, bias=True)\n",
565
+ ")"
566
+ ]
567
+ },
568
+ "execution_count": 60,
569
+ "metadata": {},
570
+ "output_type": "execute_result"
571
+ }
572
+ ],
573
+ "source": [
574
+ "device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')\n",
575
+ "model = SimpleQuestionDistilBERT(mod)\n",
576
+ "model.to(device)"
577
+ ]
578
+ },
579
+ {
580
+ "cell_type": "code",
581
+ "execution_count": 61,
582
+ "id": "f13c12dc",
583
+ "metadata": {},
584
+ "outputs": [],
585
+ "source": [
586
+ "model.train()\n",
587
+ "optim = RMSprop(model.parameters(), lr=1e-4)"
588
+ ]
589
+ },
590
+ {
591
+ "cell_type": "code",
592
+ "execution_count": 22,
593
+ "id": "e4fa54d9",
594
+ "metadata": {},
595
+ "outputs": [
596
+ {
597
+ "data": {
598
+ "application/vnd.jupyter.widget-view+json": {
599
+ "model_id": "0016d9f5ba764eb98e9df8573995c86c",
600
+ "version_major": 2,
601
+ "version_minor": 0
602
+ },
603
+ "text/plain": [
604
+ " 0%| | 0/10875 [00:00<?, ?it/s]"
605
+ ]
606
+ },
607
+ "metadata": {},
608
+ "output_type": "display_data"
609
+ },
610
+ {
611
+ "name": "stdout",
612
+ "output_type": "stream",
613
+ "text": [
614
+ "Mean Training Error 0.7555404769408292\n"
615
+ ]
616
+ },
617
+ {
618
+ "data": {
619
+ "application/vnd.jupyter.widget-view+json": {
620
+ "model_id": "96af0e22e2ee44fd920795b0e7317839",
621
+ "version_major": 2,
622
+ "version_minor": 0
623
+ },
624
+ "text/plain": [
625
+ " 0%| | 0/2500 [00:00<?, ?it/s]"
626
+ ]
627
+ },
628
+ "metadata": {},
629
+ "output_type": "display_data"
630
+ },
631
+ {
632
+ "name": "stdout",
633
+ "output_type": "stream",
634
+ "text": [
635
+ "Mean Test Error 1.761920437876694\n"
636
+ ]
637
+ },
638
+ {
639
+ "data": {
640
+ "application/vnd.jupyter.widget-view+json": {
641
+ "model_id": "5160ffe5f60e4b72b46746a33b1d60d0",
642
+ "version_major": 2,
643
+ "version_minor": 0
644
+ },
645
+ "text/plain": [
646
+ " 0%| | 0/10875 [00:00<?, ?it/s]"
647
+ ]
648
+ },
649
+ "metadata": {},
650
+ "output_type": "display_data"
651
+ },
652
+ {
653
+ "ename": "KeyboardInterrupt",
654
+ "evalue": "",
655
+ "output_type": "error",
656
+ "traceback": [
657
+ "\u001B[0;31m---------------------------------------------------------------------------\u001B[0m",
658
+ "\u001B[0;31mKeyboardInterrupt\u001B[0m Traceback (most recent call last)",
659
+ "Cell \u001B[0;32mIn [22], line 18\u001B[0m\n\u001B[1;32m 16\u001B[0m \u001B[38;5;66;03m# print(torch.argmax(outputs['start_logits'],axis=1), torch.argmax(outputs['end_logits'], axis=1), start, end)\u001B[39;00m\n\u001B[1;32m 17\u001B[0m loss \u001B[38;5;241m=\u001B[39m outputs[\u001B[38;5;124m'\u001B[39m\u001B[38;5;124mloss\u001B[39m\u001B[38;5;124m'\u001B[39m]\n\u001B[0;32m---> 18\u001B[0m loss\u001B[38;5;241m.\u001B[39mbackward()\n\u001B[1;32m 19\u001B[0m \u001B[38;5;66;03m# torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)\u001B[39;00m\n\u001B[1;32m 20\u001B[0m optim\u001B[38;5;241m.\u001B[39mstep()\n",
660
+ "File \u001B[0;32m~/Documents/University/WS2022/applieddl/venv/lib64/python3.10/site-packages/torch/_tensor.py:396\u001B[0m, in \u001B[0;36mTensor.backward\u001B[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001B[0m\n\u001B[1;32m 387\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m has_torch_function_unary(\u001B[38;5;28mself\u001B[39m):\n\u001B[1;32m 388\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m handle_torch_function(\n\u001B[1;32m 389\u001B[0m Tensor\u001B[38;5;241m.\u001B[39mbackward,\n\u001B[1;32m 390\u001B[0m (\u001B[38;5;28mself\u001B[39m,),\n\u001B[0;32m (...)\u001B[0m\n\u001B[1;32m 394\u001B[0m create_graph\u001B[38;5;241m=\u001B[39mcreate_graph,\n\u001B[1;32m 395\u001B[0m inputs\u001B[38;5;241m=\u001B[39minputs)\n\u001B[0;32m--> 396\u001B[0m \u001B[43mtorch\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mautograd\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mbackward\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mgradient\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mretain_graph\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mcreate_graph\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43minputs\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43minputs\u001B[49m\u001B[43m)\u001B[49m\n",
661
+ "File \u001B[0;32m~/Documents/University/WS2022/applieddl/venv/lib64/python3.10/site-packages/torch/autograd/__init__.py:173\u001B[0m, in \u001B[0;36mbackward\u001B[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001B[0m\n\u001B[1;32m 168\u001B[0m retain_graph \u001B[38;5;241m=\u001B[39m create_graph\n\u001B[1;32m 170\u001B[0m \u001B[38;5;66;03m# The reason we repeat same the comment below is that\u001B[39;00m\n\u001B[1;32m 171\u001B[0m \u001B[38;5;66;03m# some Python versions print out the first line of a multi-line function\u001B[39;00m\n\u001B[1;32m 172\u001B[0m \u001B[38;5;66;03m# calls in the traceback and some print out the last line\u001B[39;00m\n\u001B[0;32m--> 173\u001B[0m \u001B[43mVariable\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43m_execution_engine\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mrun_backward\u001B[49m\u001B[43m(\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;66;43;03m# Calls into the C++ engine to run the backward pass\u001B[39;49;00m\n\u001B[1;32m 174\u001B[0m \u001B[43m \u001B[49m\u001B[43mtensors\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mgrad_tensors_\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mretain_graph\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mcreate_graph\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43minputs\u001B[49m\u001B[43m,\u001B[49m\n\u001B[1;32m 175\u001B[0m \u001B[43m \u001B[49m\u001B[43mallow_unreachable\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;28;43;01mTrue\u001B[39;49;00m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43maccumulate_grad\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;28;43;01mTrue\u001B[39;49;00m\u001B[43m)\u001B[49m\n",
662
+ "\u001B[0;31mKeyboardInterrupt\u001B[0m: "
663
+ ]
664
+ }
665
+ ],
666
+ "source": [
667
+ "epochs = 5\n",
668
+ "\n",
669
+ "for epoch in range(epochs):\n",
670
+ " loop = tqdm(loader, leave=True)\n",
671
+ " model.train()\n",
672
+ " mean_training_error = []\n",
673
+ " for batch in loop:\n",
674
+ " optim.zero_grad()\n",
675
+ " \n",
676
+ " input_ids = batch['input_ids'].to(device)\n",
677
+ " attention_mask = batch['attention_mask'].to(device)\n",
678
+ " start = batch['start_positions'].to(device)\n",
679
+ " end = batch['end_positions'].to(device)\n",
680
+ " \n",
681
+ " outputs = model(input_ids, attention_mask=attention_mask, start_positions=start, end_positions=end)\n",
682
+ " # print(torch.argmax(outputs['start_logits'],axis=1), torch.argmax(outputs['end_logits'], axis=1), start, end)\n",
683
+ " loss = outputs['loss']\n",
684
+ " loss.backward()\n",
685
+ " # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)\n",
686
+ " optim.step()\n",
687
+ " mean_training_error.append(loss.item())\n",
688
+ " loop.set_description(f'Epoch {epoch}')\n",
689
+ " loop.set_postfix(loss=loss.item())\n",
690
+ " print(\"Mean Training Error\", np.mean(mean_training_error))\n",
691
+ " \n",
692
+ " \n",
693
+ " loop = tqdm(test_loader, leave=True)\n",
694
+ " model.eval()\n",
695
+ " mean_test_error = []\n",
696
+ " for batch in loop:\n",
697
+ " \n",
698
+ " input_ids = batch['input_ids'].to(device)\n",
699
+ " attention_mask = batch['attention_mask'].to(device)\n",
700
+ " start = batch['start_positions'].to(device)\n",
701
+ " end = batch['end_positions'].to(device)\n",
702
+ " \n",
703
+ " outputs = model(input_ids, attention_mask=attention_mask, start_positions=start, end_positions=end)\n",
704
+ " # print(torch.argmax(outputs['start_logits'],axis=1), torch.argmax(outputs['end_logits'], axis=1), start, end)\n",
705
+ " loss = outputs['loss']\n",
706
+ " \n",
707
+ " mean_test_error.append(loss.item())\n",
708
+ " loop.set_description(f'Epoch {epoch} Testset')\n",
709
+ " loop.set_postfix(loss=loss.item())\n",
710
+ " print(\"Mean Test Error\", np.mean(mean_test_error))"
711
+ ]
712
+ },
713
+ {
714
+ "cell_type": "code",
715
+ "execution_count": 19,
716
+ "id": "6ff26fb4",
717
+ "metadata": {},
718
+ "outputs": [],
719
+ "source": [
720
+ "torch.save(model.state_dict(), \"simple_distilbert_qa.model\")"
721
+ ]
722
+ },
723
+ {
724
+ "cell_type": "code",
725
+ "execution_count": 20,
726
+ "id": "a5e7abeb",
727
+ "metadata": {},
728
+ "outputs": [
729
+ {
730
+ "data": {
731
+ "text/plain": [
732
+ "<All keys matched successfully>"
733
+ ]
734
+ },
735
+ "execution_count": 20,
736
+ "metadata": {},
737
+ "output_type": "execute_result"
738
+ }
739
+ ],
740
+ "source": [
741
+ "model = SimpleQuestionDistilBERT(mod)\n",
742
+ "model.load_state_dict(torch.load(\"simple_distilbert_qa.model\"))"
743
+ ]
744
+ },
745
+ {
746
+ "cell_type": "code",
747
+ "execution_count": 18,
748
+ "id": "f5ad7bee",
749
+ "metadata": {},
750
+ "outputs": [
751
+ {
752
+ "name": "stderr",
753
+ "output_type": "stream",
754
+ "text": [
755
+ "100%|鈻堚枅鈻堚枅鈻堚枅鈻堚枅鈻堚枅| 2500/2500 [02:09<00:00, 19.37it/s]"
756
+ ]
757
+ },
758
+ {
759
+ "name": "stdout",
760
+ "output_type": "stream",
761
+ "text": [
762
+ "Mean EM: 0.5374\n",
763
+ "Mean F-1: 0.6826317532406944\n"
764
+ ]
765
+ },
766
+ {
767
+ "name": "stderr",
768
+ "output_type": "stream",
769
+ "text": [
770
+ "\n"
771
+ ]
772
+ }
773
+ ],
774
+ "source": [
775
+ "eval_test_set(model, tokenizer, test_loader, device)"
776
+ ]
777
+ },
778
+ {
779
+ "cell_type": "markdown",
780
+ "id": "fa6017a8",
781
+ "metadata": {},
782
+ "source": [
783
+ "## Freeze baseline and train new head\n",
784
+ "This was my initial idea, to freeze the layers and add a completely new head, which we train from scratch. I tried a lot of different configurations, but nothing really worked, I usually stayed at a CrossEntropyLoss of about 3 the whole time. Below, you can see the different heads I have tried.\n",
785
+ "\n",
786
+ "Furthermore, I experimented with different data, because I though it might not be enough data all in all. I would conclude that this didn't work because (1) Transformers are very data-hungry and I probably still used too little data (one epoch took about 1h though, so it wasn't possible to use even more). (2) We train the layers completely new, which means they contain absolutely no structure about the problem and task beforehand. I do not think that this way of training leads to better results / less energy used all in all, because it would be too resource intense.\n",
787
+ "\n",
788
+ "The following setup is partly based on the HuggingFace implementation of the question answering model (https://github.com/huggingface/transformers/blob/v4.23.1/src/transformers/models/distilbert/modeling_distilbert.py#L805)"
789
+ ]
790
+ },
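The parameter table printed further down shows that only the added encoder and classifier of `QuestionDistilBERT` are trainable, i.e. the DistilBERT backbone is frozen inside the class. A standalone sketch of that freezing step (illustrative only; the class itself takes care of this):

    import torch
    from transformers import DistilBertForMaskedLM
    from qa_model import QuestionDistilBERT

    base = DistilBertForMaskedLM.from_pretrained("distilbert-base-uncased").distilbert
    for param in base.parameters():
        param.requires_grad = False          # no gradients for the pretrained backbone

    model = QuestionDistilBERT(base)
    trainable = [p for p in model.parameters() if p.requires_grad]
    optim = torch.optim.RMSprop(trainable, lr=1e-4)   # optimise only the new head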
791
+ {
792
+ "cell_type": "code",
793
+ "execution_count": 62,
794
+ "id": "92b21967",
795
+ "metadata": {},
796
+ "outputs": [],
797
+ "source": [
798
+ "model = DistilBertForMaskedLM.from_pretrained(\"distilbert-base-uncased\")"
799
+ ]
800
+ },
801
+ {
802
+ "cell_type": "code",
803
+ "execution_count": 63,
804
+ "id": "1d7b3a8c",
805
+ "metadata": {},
806
+ "outputs": [],
807
+ "source": [
808
+ "config = DistilBertConfig.from_pretrained(\"distilbert-base-uncased\")"
809
+ ]
810
+ },
811
+ {
812
+ "cell_type": "code",
813
+ "execution_count": 64,
814
+ "id": "91444894",
815
+ "metadata": {},
816
+ "outputs": [],
817
+ "source": [
818
+ "# only take base model, we do not need the classification head\n",
819
+ "mod = model.distilbert"
820
+ ]
821
+ },
822
+ {
823
+ "cell_type": "code",
824
+ "execution_count": 65,
825
+ "id": "74ca6c07",
826
+ "metadata": {},
827
+ "outputs": [
828
+ {
829
+ "data": {
830
+ "text/plain": [
831
+ "QuestionDistilBERT(\n",
832
+ " (distilbert): DistilBertModel(\n",
833
+ " (embeddings): Embeddings(\n",
834
+ " (word_embeddings): Embedding(30522, 768, padding_idx=0)\n",
835
+ " (position_embeddings): Embedding(512, 768)\n",
836
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
837
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
838
+ " )\n",
839
+ " (transformer): Transformer(\n",
840
+ " (layer): ModuleList(\n",
841
+ " (0): TransformerBlock(\n",
842
+ " (attention): MultiHeadSelfAttention(\n",
843
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
844
+ " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
845
+ " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
846
+ " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
847
+ " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
848
+ " )\n",
849
+ " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
850
+ " (ffn): FFN(\n",
851
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
852
+ " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
853
+ " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
854
+ " (activation): GELUActivation()\n",
855
+ " )\n",
856
+ " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
857
+ " )\n",
858
+ " (1): TransformerBlock(\n",
859
+ " (attention): MultiHeadSelfAttention(\n",
860
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
861
+ " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
862
+ " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
863
+ " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
864
+ " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
865
+ " )\n",
866
+ " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
867
+ " (ffn): FFN(\n",
868
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
869
+ " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
870
+ " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
871
+ " (activation): GELUActivation()\n",
872
+ " )\n",
873
+ " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
874
+ " )\n",
875
+ " (2): TransformerBlock(\n",
876
+ " (attention): MultiHeadSelfAttention(\n",
877
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
878
+ " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
879
+ " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
880
+ " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
881
+ " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
882
+ " )\n",
883
+ " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
884
+ " (ffn): FFN(\n",
885
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
886
+ " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
887
+ " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
888
+ " (activation): GELUActivation()\n",
889
+ " )\n",
890
+ " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
891
+ " )\n",
892
+ " (3): TransformerBlock(\n",
893
+ " (attention): MultiHeadSelfAttention(\n",
894
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
895
+ " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
896
+ " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
897
+ " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
898
+ " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
899
+ " )\n",
900
+ " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
901
+ " (ffn): FFN(\n",
902
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
903
+ " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
904
+ " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
905
+ " (activation): GELUActivation()\n",
906
+ " )\n",
907
+ " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
908
+ " )\n",
909
+ " (4): TransformerBlock(\n",
910
+ " (attention): MultiHeadSelfAttention(\n",
911
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
912
+ " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
913
+ " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
914
+ " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
915
+ " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
916
+ " )\n",
917
+ " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
918
+ " (ffn): FFN(\n",
919
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
920
+ " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
921
+ " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
922
+ " (activation): GELUActivation()\n",
923
+ " )\n",
924
+ " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
925
+ " )\n",
926
+ " (5): TransformerBlock(\n",
927
+ " (attention): MultiHeadSelfAttention(\n",
928
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
929
+ " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
930
+ " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
931
+ " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
932
+ " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
933
+ " )\n",
934
+ " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
935
+ " (ffn): FFN(\n",
936
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
937
+ " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
938
+ " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
939
+ " (activation): GELUActivation()\n",
940
+ " )\n",
941
+ " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
942
+ " )\n",
943
+ " )\n",
944
+ " )\n",
945
+ " )\n",
946
+ " (relu): ReLU()\n",
947
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
948
+ " (te): TransformerEncoder(\n",
949
+ " (layers): ModuleList(\n",
950
+ " (0): TransformerEncoderLayer(\n",
951
+ " (self_attn): MultiheadAttention(\n",
952
+ " (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)\n",
953
+ " )\n",
954
+ " (linear1): Linear(in_features=768, out_features=2048, bias=True)\n",
955
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
956
+ " (linear2): Linear(in_features=2048, out_features=768, bias=True)\n",
957
+ " (norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
958
+ " (norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
959
+ " (dropout1): Dropout(p=0.1, inplace=False)\n",
960
+ " (dropout2): Dropout(p=0.1, inplace=False)\n",
961
+ " )\n",
962
+ " (1): TransformerEncoderLayer(\n",
963
+ " (self_attn): MultiheadAttention(\n",
964
+ " (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)\n",
965
+ " )\n",
966
+ " (linear1): Linear(in_features=768, out_features=2048, bias=True)\n",
967
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
968
+ " (linear2): Linear(in_features=2048, out_features=768, bias=True)\n",
969
+ " (norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
970
+ " (norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
971
+ " (dropout1): Dropout(p=0.1, inplace=False)\n",
972
+ " (dropout2): Dropout(p=0.1, inplace=False)\n",
973
+ " )\n",
974
+ " (2): TransformerEncoderLayer(\n",
975
+ " (self_attn): MultiheadAttention(\n",
976
+ " (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)\n",
977
+ " )\n",
978
+ " (linear1): Linear(in_features=768, out_features=2048, bias=True)\n",
979
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
980
+ " (linear2): Linear(in_features=2048, out_features=768, bias=True)\n",
981
+ " (norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
982
+ " (norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
983
+ " (dropout1): Dropout(p=0.1, inplace=False)\n",
984
+ " (dropout2): Dropout(p=0.1, inplace=False)\n",
985
+ " )\n",
986
+ " )\n",
987
+ " )\n",
988
+ " (classifier): Sequential(\n",
989
+ " (0): Dropout(p=0.1, inplace=False)\n",
990
+ " (1): ReLU()\n",
991
+ " (2): Linear(in_features=768, out_features=512, bias=True)\n",
992
+ " (3): Dropout(p=0.1, inplace=False)\n",
993
+ " (4): ReLU()\n",
994
+ " (5): Linear(in_features=512, out_features=256, bias=True)\n",
995
+ " (6): Dropout(p=0.1, inplace=False)\n",
996
+ " (7): ReLU()\n",
997
+ " (8): Linear(in_features=256, out_features=128, bias=True)\n",
998
+ " (9): Dropout(p=0.1, inplace=False)\n",
999
+ " (10): ReLU()\n",
1000
+ " (11): Linear(in_features=128, out_features=64, bias=True)\n",
1001
+ " (12): Dropout(p=0.1, inplace=False)\n",
1002
+ " (13): ReLU()\n",
1003
+ " (14): Linear(in_features=64, out_features=2, bias=True)\n",
1004
+ " )\n",
1005
+ ")"
1006
+ ]
1007
+ },
1008
+ "execution_count": 65,
1009
+ "metadata": {},
1010
+ "output_type": "execute_result"
1011
+ }
1012
+ ],
1013
+ "source": [
1014
+ "device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')\n",
1015
+ "model = QuestionDistilBERT(mod)\n",
1016
+ "model.to(device)"
1017
+ ]
1018
+ },
1019
+ {
1020
+ "cell_type": "code",
1021
+ "execution_count": 66,
1022
+ "id": "340857f9",
1023
+ "metadata": {},
1024
+ "outputs": [
1025
+ {
1026
+ "name": "stdout",
1027
+ "output_type": "stream",
1028
+ "text": [
1029
+ "+---------------------------------------+------------+\n",
1030
+ "| Modules | Parameters |\n",
1031
+ "+---------------------------------------+------------+\n",
1032
+ "| te.layers.0.self_attn.in_proj_weight | 1769472 |\n",
1033
+ "| te.layers.0.self_attn.in_proj_bias | 2304 |\n",
1034
+ "| te.layers.0.self_attn.out_proj.weight | 589824 |\n",
1035
+ "| te.layers.0.self_attn.out_proj.bias | 768 |\n",
1036
+ "| te.layers.0.linear1.weight | 1572864 |\n",
1037
+ "| te.layers.0.linear1.bias | 2048 |\n",
1038
+ "| te.layers.0.linear2.weight | 1572864 |\n",
1039
+ "| te.layers.0.linear2.bias | 768 |\n",
1040
+ "| te.layers.0.norm1.weight | 768 |\n",
1041
+ "| te.layers.0.norm1.bias | 768 |\n",
1042
+ "| te.layers.0.norm2.weight | 768 |\n",
1043
+ "| te.layers.0.norm2.bias | 768 |\n",
1044
+ "| te.layers.1.self_attn.in_proj_weight | 1769472 |\n",
1045
+ "| te.layers.1.self_attn.in_proj_bias | 2304 |\n",
1046
+ "| te.layers.1.self_attn.out_proj.weight | 589824 |\n",
1047
+ "| te.layers.1.self_attn.out_proj.bias | 768 |\n",
1048
+ "| te.layers.1.linear1.weight | 1572864 |\n",
1049
+ "| te.layers.1.linear1.bias | 2048 |\n",
1050
+ "| te.layers.1.linear2.weight | 1572864 |\n",
1051
+ "| te.layers.1.linear2.bias | 768 |\n",
1052
+ "| te.layers.1.norm1.weight | 768 |\n",
1053
+ "| te.layers.1.norm1.bias | 768 |\n",
1054
+ "| te.layers.1.norm2.weight | 768 |\n",
1055
+ "| te.layers.1.norm2.bias | 768 |\n",
1056
+ "| te.layers.2.self_attn.in_proj_weight | 1769472 |\n",
1057
+ "| te.layers.2.self_attn.in_proj_bias | 2304 |\n",
1058
+ "| te.layers.2.self_attn.out_proj.weight | 589824 |\n",
1059
+ "| te.layers.2.self_attn.out_proj.bias | 768 |\n",
1060
+ "| te.layers.2.linear1.weight | 1572864 |\n",
1061
+ "| te.layers.2.linear1.bias | 2048 |\n",
1062
+ "| te.layers.2.linear2.weight | 1572864 |\n",
1063
+ "| te.layers.2.linear2.bias | 768 |\n",
1064
+ "| te.layers.2.norm1.weight | 768 |\n",
1065
+ "| te.layers.2.norm1.bias | 768 |\n",
1066
+ "| te.layers.2.norm2.weight | 768 |\n",
1067
+ "| te.layers.2.norm2.bias | 768 |\n",
1068
+ "| classifier.2.weight | 393216 |\n",
1069
+ "| classifier.2.bias | 512 |\n",
1070
+ "| classifier.5.weight | 131072 |\n",
1071
+ "| classifier.5.bias | 256 |\n",
1072
+ "| classifier.8.weight | 32768 |\n",
1073
+ "| classifier.8.bias | 128 |\n",
1074
+ "| classifier.11.weight | 8192 |\n",
1075
+ "| classifier.11.bias | 64 |\n",
1076
+ "| classifier.14.weight | 128 |\n",
1077
+ "| classifier.14.bias | 2 |\n",
1078
+ "+---------------------------------------+------------+\n",
1079
+ "Total Trainable Params: 17108290\n"
1080
+ ]
1081
+ },
1082
+ {
1083
+ "data": {
1084
+ "text/plain": [
1085
+ "17108290"
1086
+ ]
1087
+ },
1088
+ "execution_count": 66,
1089
+ "metadata": {},
1090
+ "output_type": "execute_result"
1091
+ }
1092
+ ],
1093
+ "source": [
1094
+ "count_parameters(model)"
1095
+ ]
1096
+ },
1097
+ {
1098
+ "cell_type": "markdown",
1099
+ "id": "9babd013",
1100
+ "metadata": {},
1101
+ "source": [
1102
+ "### Testing the model\n",
1103
+ "This is the same procedure as in `distilbert.ipynb`. "
1104
+ ]
1105
+ },
1106
+ {
1107
+ "cell_type": "code",
1108
+ "execution_count": 67,
1109
+ "id": "694c828b",
1110
+ "metadata": {},
1111
+ "outputs": [],
1112
+ "source": [
1113
+ "# get smaller dataset\n",
1114
+ "batch_size = 8\n",
1115
+ "test_ds = Dataset(squad_paths = squad_paths[:2], natural_question_paths=None, hotpotqa_paths=None, tokenizer=tokenizer)\n",
1116
+ "test_ds_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size)\n",
1117
+ "optim=torch.optim.Adam(model.parameters())"
1118
+ ]
1119
+ },
1120
+ {
1121
+ "cell_type": "code",
1122
+ "execution_count": 68,
1123
+ "id": "a76587df",
1124
+ "metadata": {},
1125
+ "outputs": [
1126
+ {
1127
+ "name": "stdout",
1128
+ "output_type": "stream",
1129
+ "text": [
1130
+ "Passed\n"
1131
+ ]
1132
+ }
1133
+ ],
1134
+ "source": [
1135
+ "test_model(model, optim, test_ds_loader, device)"
1136
+ ]
1137
+ },
1138
+ {
1139
+ "cell_type": "markdown",
1140
+ "id": "7c326e8e",
1141
+ "metadata": {},
1142
+ "source": [
1143
+ "### Training the model\n",
1144
+ "* Parameter Tuning:\n",
1145
+ " * Learning Rate: I experimented with several values, 1e-4 seemed to work best for me. 1e-3 was very unstable and 1e-5 was too small.\n",
1146
+ " * Gradient Clipping: I experimented with this, but the difference was only minimal\n",
1147
+ "\n",
1148
+ "Data:\n",
1149
+ "* I first used only the SQuAD dataset, but generalisation is a problem\n",
1150
+ " * The dataset is realtively small and we often have entries with the same context but different questions\n",
1151
+ " * I believe, the diversity is not big enough to train a fully functional model\n",
1152
+ "* Hence, I included the Natural Questions dataset too\n",
1153
+ " * It is however a lot more messy - I elaborated a bit more on this in `load_data.ipynb`\n",
1154
+ "* Also the hotpotqa data was used\n",
1155
+ "\n",
1156
+ "Tested with: \n",
1157
+ "* 3 Linear Layers\n",
1158
+ " * Training Error high - needed more layers\n",
1159
+ " * Already expected - this was mostly a Proof of Concept\n",
1160
+ "* 1 TransformerEncoder with 4 attention heads + 1 Linear Layer:\n",
1161
+ " * Training Error was high, still too simple\n",
1162
+ "* 1 TransformerEncoder with 8 heads + 1 Linear Layer:\n",
1163
+ " * Training Error gets lower, however stagnates at some point\n",
1164
+ " * Probably still too simple, it doesn't generalise either\n",
1165
+ "* 2 TransformerEncoder with 8 and 4 heads + 1 Linear Layer:\n",
1166
+ " * Loss gets down but doesn't go further after some time\n"
1167
+ ]
1168
+ },
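+ {
+ "cell_type": "markdown",
+ "id": "qa-head-sketch",
+ "metadata": {},
+ "source": [
+ "The cell below is a minimal sketch (an illustration, not the actual `QuestionDistilBERT` code in `qa_model.py`) of the last variant in the list above: the DistilBERT body followed by a small `nn.TransformerEncoder` stack and one linear layer that yields a start and an end logit per token. The class name `TransformerHeadSketch` and its arguments are hypothetical.\n",
+ "\n",
+ "```python\n",
+ "import torch.nn as nn\n",
+ "\n",
+ "class TransformerHeadSketch(nn.Module):\n",
+ "    def __init__(self, backbone, hidden=768, nhead=8, nlayers=2):\n",
+ "        super().__init__()\n",
+ "        self.backbone = backbone  # pre-trained DistilBERT body\n",
+ "        layer = nn.TransformerEncoderLayer(d_model=hidden, nhead=nhead, batch_first=True)\n",
+ "        self.encoder = nn.TransformerEncoder(layer, num_layers=nlayers)\n",
+ "        self.qa_outputs = nn.Linear(hidden, 2)  # one start and one end logit per token\n",
+ "\n",
+ "    def forward(self, input_ids, attention_mask, start_positions=None, end_positions=None):\n",
+ "        h = self.backbone(input_ids, attention_mask=attention_mask).last_hidden_state\n",
+ "        h = self.encoder(h, src_key_padding_mask=~attention_mask.bool())\n",
+ "        start_logits, end_logits = self.qa_outputs(h).split(1, dim=-1)\n",
+ "        start_logits, end_logits = start_logits.squeeze(-1), end_logits.squeeze(-1)\n",
+ "        loss = None\n",
+ "        if start_positions is not None:\n",
+ "            ce = nn.CrossEntropyLoss()\n",
+ "            loss = (ce(start_logits, start_positions) + ce(end_logits, end_positions)) / 2\n",
+ "        return {'loss': loss, 'start_logits': start_logits, 'end_logits': end_logits}\n",
+ "```\n",
+ "\n",
+ "Treating the start and end positions as two token-level classification problems and averaging their cross-entropy losses is the standard extractive-QA setup that the training loop below relies on via `outputs['loss']`."
+ ]
+ },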
1169
+ {
1170
+ "cell_type": "code",
1171
+ "execution_count": null,
1172
+ "id": "2e9f4bd3",
1173
+ "metadata": {},
1174
+ "outputs": [],
1175
+ "source": [
1176
+ "dataset = Dataset(squad_paths = squad_paths, natural_question_paths=nat_paths, hotpotqa_paths=hotpotqa_paths, tokenizer=tokenizer)\n",
1177
+ "loader = torch.utils.data.DataLoader(dataset, batch_size=8)\n",
1178
+ "\n",
1179
+ "test_dataset = Dataset(squad_paths = [str(x) for x in Path('data/test_squad/').glob('**/*.txt')], \n",
1180
+ " natural_question_paths=None, \n",
1181
+ " hotpotqa_paths = None, tokenizer=tokenizer)\n",
1182
+ "test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=4)"
1183
+ ]
1184
+ },
1185
+ {
1186
+ "cell_type": "code",
1187
+ "execution_count": 26,
1188
+ "id": "03a6de37",
1189
+ "metadata": {},
1190
+ "outputs": [],
1191
+ "source": [
1192
+ "model = QuestionDistilBERT(mod)"
1193
+ ]
1194
+ },
1195
+ {
1196
+ "cell_type": "code",
1197
+ "execution_count": 41,
1198
+ "id": "ed854b73",
1199
+ "metadata": {},
1200
+ "outputs": [],
1201
+ "source": [
1202
+ "from torch.optim import AdamW, RMSprop\n",
1203
+ "\n",
1204
+ "model.train()\n",
1205
+ "optim = RMSprop(model.parameters(), lr=1e-4)"
1206
+ ]
1207
+ },
1208
+ {
1209
+ "cell_type": "code",
1210
+ "execution_count": 42,
1211
+ "id": "79fdfcc9",
1212
+ "metadata": {},
1213
+ "outputs": [],
1214
+ "source": [
1215
+ "from torch.utils.tensorboard import SummaryWriter\n",
1216
+ "writer = SummaryWriter()"
1217
+ ]
1218
+ },
1219
+ {
1220
+ "cell_type": "code",
1221
+ "execution_count": null,
1222
+ "id": "f7bddb43",
1223
+ "metadata": {},
1224
+ "outputs": [
1225
+ {
1226
+ "data": {
1227
+ "application/vnd.jupyter.widget-view+json": {
1228
+ "model_id": "5e9e74167c4b4b22b3218f4ca3c5abf0",
1229
+ "version_major": 2,
1230
+ "version_minor": 0
1231
+ },
1232
+ "text/plain": [
1233
+ " 0%| | 0/21750 [00:00<?, ?it/s]"
1234
+ ]
1235
+ },
1236
+ "metadata": {},
1237
+ "output_type": "display_data"
1238
+ },
1239
+ {
1240
+ "name": "stdout",
1241
+ "output_type": "stream",
1242
+ "text": [
1243
+ "Mean Training Error 3.8791405910185013\n"
1244
+ ]
1245
+ },
1246
+ {
1247
+ "data": {
1248
+ "application/vnd.jupyter.widget-view+json": {
1249
+ "model_id": "f3ce562fc61d4bfc83a4860eb06bc20c",
1250
+ "version_major": 2,
1251
+ "version_minor": 0
1252
+ },
1253
+ "text/plain": [
1254
+ " 0%| | 0/1250 [00:00<?, ?it/s]"
1255
+ ]
1256
+ },
1257
+ "metadata": {},
1258
+ "output_type": "display_data"
1259
+ },
1260
+ {
1261
+ "name": "stdout",
1262
+ "output_type": "stream",
1263
+ "text": [
1264
+ "Mean Test Error 3.7705092002868654\n"
1265
+ ]
1266
+ },
1267
+ {
1268
+ "data": {
1269
+ "application/vnd.jupyter.widget-view+json": {
1270
+ "model_id": "2e84e21cedd446a0a5f5a40501711d1c",
1271
+ "version_major": 2,
1272
+ "version_minor": 0
1273
+ },
1274
+ "text/plain": [
1275
+ " 0%| | 0/21750 [00:00<?, ?it/s]"
1276
+ ]
1277
+ },
1278
+ "metadata": {},
1279
+ "output_type": "display_data"
1280
+ },
1281
+ {
1282
+ "name": "stdout",
1283
+ "output_type": "stream",
1284
+ "text": [
1285
+ "Mean Training Error 3.7389922174091996\n"
1286
+ ]
1287
+ },
1288
+ {
1289
+ "data": {
1290
+ "application/vnd.jupyter.widget-view+json": {
1291
+ "model_id": "07135c48be0146498cd37d767c1ee6ab",
1292
+ "version_major": 2,
1293
+ "version_minor": 0
1294
+ },
1295
+ "text/plain": [
1296
+ " 0%| | 0/1250 [00:00<?, ?it/s]"
1297
+ ]
1298
+ },
1299
+ "metadata": {},
1300
+ "output_type": "display_data"
1301
+ },
1302
+ {
1303
+ "name": "stdout",
1304
+ "output_type": "stream",
1305
+ "text": [
1306
+ "Mean Test Error 3.7443671816825868\n"
1307
+ ]
1308
+ },
1309
+ {
1310
+ "data": {
1311
+ "application/vnd.jupyter.widget-view+json": {
1312
+ "model_id": "e9a51fbabc7043c2819a68e247e4a3ec",
1313
+ "version_major": 2,
1314
+ "version_minor": 0
1315
+ },
1316
+ "text/plain": [
1317
+ " 0%| | 0/21750 [00:00<?, ?it/s]"
1318
+ ]
1319
+ },
1320
+ "metadata": {},
1321
+ "output_type": "display_data"
1322
+ },
1323
+ {
1324
+ "name": "stdout",
1325
+ "output_type": "stream",
1326
+ "text": [
1327
+ "Mean Training Error 3.7031057048117977\n"
1328
+ ]
1329
+ },
1330
+ {
1331
+ "data": {
1332
+ "application/vnd.jupyter.widget-view+json": {
1333
+ "model_id": "bfdbcc9fe32542a19c47bc1d7704400e",
1334
+ "version_major": 2,
1335
+ "version_minor": 0
1336
+ },
1337
+ "text/plain": [
1338
+ " 0%| | 0/1250 [00:00<?, ?it/s]"
1339
+ ]
1340
+ },
1341
+ "metadata": {},
1342
+ "output_type": "display_data"
1343
+ },
1344
+ {
1345
+ "name": "stdout",
1346
+ "output_type": "stream",
1347
+ "text": [
1348
+ "Mean Test Error 3.743248237323761\n"
1349
+ ]
1350
+ },
1351
+ {
1352
+ "data": {
1353
+ "application/vnd.jupyter.widget-view+json": {
1354
+ "model_id": "81fd1278b22643dc9fb3ac306533a240",
1355
+ "version_major": 2,
1356
+ "version_minor": 0
1357
+ },
1358
+ "text/plain": [
1359
+ " 0%| | 0/21750 [00:00<?, ?it/s]"
1360
+ ]
1361
+ },
1362
+ "metadata": {},
1363
+ "output_type": "display_data"
1364
+ },
1365
+ {
1366
+ "name": "stdout",
1367
+ "output_type": "stream",
1368
+ "text": [
1369
+ "Mean Training Error 3.6711661003430685\n"
1370
+ ]
1371
+ },
1372
+ {
1373
+ "data": {
1374
+ "application/vnd.jupyter.widget-view+json": {
1375
+ "model_id": "8b38d6cd44e048ec8bcd6b5cb86cce16",
1376
+ "version_major": 2,
1377
+ "version_minor": 0
1378
+ },
1379
+ "text/plain": [
1380
+ " 0%| | 0/1250 [00:00<?, ?it/s]"
1381
+ ]
1382
+ },
1383
+ "metadata": {},
1384
+ "output_type": "display_data"
1385
+ },
1386
+ {
1387
+ "name": "stdout",
1388
+ "output_type": "stream",
1389
+ "text": [
1390
+ "Mean Test Error 3.740310479736328\n"
1391
+ ]
1392
+ },
1393
+ {
1394
+ "data": {
1395
+ "application/vnd.jupyter.widget-view+json": {
1396
+ "model_id": "825248aa3f934f4aade9d973e6f3b43e",
1397
+ "version_major": 2,
1398
+ "version_minor": 0
1399
+ },
1400
+ "text/plain": [
1401
+ " 0%| | 0/21750 [00:00<?, ?it/s]"
1402
+ ]
1403
+ },
1404
+ "metadata": {},
1405
+ "output_type": "display_data"
1406
+ },
1407
+ {
1408
+ "name": "stdout",
1409
+ "output_type": "stream",
1410
+ "text": [
1411
+ "Mean Training Error 3.6591619139813827\n"
1412
+ ]
1413
+ },
1414
+ {
1415
+ "data": {
1416
+ "application/vnd.jupyter.widget-view+json": {
1417
+ "model_id": "edceb7af0ec6450997820967638c12db",
1418
+ "version_major": 2,
1419
+ "version_minor": 0
1420
+ },
1421
+ "text/plain": [
1422
+ " 0%| | 0/1250 [00:00<?, ?it/s]"
1423
+ ]
1424
+ },
1425
+ "metadata": {},
1426
+ "output_type": "display_data"
1427
+ },
1428
+ {
1429
+ "name": "stdout",
1430
+ "output_type": "stream",
1431
+ "text": [
1432
+ "Mean Test Error 3.8138498876571654\n"
1433
+ ]
1434
+ },
1435
+ {
1436
+ "data": {
1437
+ "application/vnd.jupyter.widget-view+json": {
1438
+ "model_id": "27e903eb0d0f4f949c234e4faf4277a1",
1439
+ "version_major": 2,
1440
+ "version_minor": 0
1441
+ },
1442
+ "text/plain": [
1443
+ " 0%| | 0/21750 [00:00<?, ?it/s]"
1444
+ ]
1445
+ },
1446
+ "metadata": {},
1447
+ "output_type": "display_data"
1448
+ }
1449
+ ],
1450
+ "source": [
1451
+ "epochs = 20\n",
1452
+ "\n",
1453
+ "for epoch in range(epochs):\n",
1454
+ " loop = tqdm(loader, leave=True)\n",
1455
+ " model.train()\n",
1456
+ " mean_training_error = []\n",
1457
+ " for batch in loop:\n",
1458
+ " optim.zero_grad()\n",
1459
+ " \n",
1460
+ " input_ids = batch['input_ids'].to(device)\n",
1461
+ " attention_mask = batch['attention_mask'].to(device)\n",
1462
+ " start = batch['start_positions'].to(device)\n",
1463
+ " end = batch['end_positions'].to(device)\n",
1464
+ " \n",
1465
+ " outputs = model(input_ids, attention_mask=attention_mask, start_positions=start, end_positions=end)\n",
1466
+ " \n",
1467
+ " loss = outputs['loss']\n",
1468
+ " loss.backward()\n",
1469
+ " \n",
1470
+ " optim.step()\n",
1471
+ " mean_training_error.append(loss.item())\n",
1472
+ " loop.set_description(f'Epoch {epoch}')\n",
1473
+ " loop.set_postfix(loss=loss.item())\n",
1474
+ " print(\"Mean Training Error\", np.mean(mean_training_error))\n",
1475
+ " writer.add_scalar(\"Loss/train\", np.mean(mean_training_error), epoch)\n",
1476
+ " \n",
1477
+ " loop = tqdm(test_loader, leave=True)\n",
1478
+ " model.eval()\n",
1479
+ " mean_test_error = []\n",
1480
+ " for batch in loop:\n",
1481
+ " \n",
1482
+ " input_ids = batch['input_ids'].to(device)\n",
1483
+ " attention_mask = batch['attention_mask'].to(device)\n",
1484
+ " start = batch['start_positions'].to(device)\n",
1485
+ " end = batch['end_positions'].to(device)\n",
1486
+ " \n",
1487
+ " outputs = model(input_ids, attention_mask=attention_mask, start_positions=start, end_positions=end)\n",
1488
+ " # print(torch.argmax(outputs['start_logits'],axis=1), torch.argmax(outputs['end_logits'], axis=1), start, end)\n",
1489
+ " loss = outputs['loss']\n",
1490
+ " \n",
1491
+ " mean_test_error.append(loss.item())\n",
1492
+ " loop.set_description(f'Epoch {epoch} Testset')\n",
1493
+ " loop.set_postfix(loss=loss.item())\n",
1494
+ " print(\"Mean Test Error\", np.mean(mean_test_error))\n",
1495
+ " writer.add_scalar(\"Loss/test\", np.mean(mean_test_error), epoch)"
1496
+ ]
1497
+ },
1498
+ {
1499
+ "cell_type": "code",
1500
+ "execution_count": 238,
1501
+ "id": "a9d6af2e",
1502
+ "metadata": {},
1503
+ "outputs": [],
1504
+ "source": [
1505
+ "writer.close()"
1506
+ ]
1507
+ },
1508
+ {
1509
+ "cell_type": "code",
1510
+ "execution_count": 33,
1511
+ "id": "ba43447e",
1512
+ "metadata": {},
1513
+ "outputs": [],
1514
+ "source": [
1515
+ "torch.save(model.state_dict(), \"distilbert_qa.model\")"
1516
+ ]
1517
+ },
1518
+ {
1519
+ "cell_type": "code",
1520
+ "execution_count": 34,
1521
+ "id": "ffc49aca",
1522
+ "metadata": {},
1523
+ "outputs": [
1524
+ {
1525
+ "data": {
1526
+ "text/plain": [
1527
+ "<All keys matched successfully>"
1528
+ ]
1529
+ },
1530
+ "execution_count": 34,
1531
+ "metadata": {},
1532
+ "output_type": "execute_result"
1533
+ }
1534
+ ],
1535
+ "source": [
1536
+ "model = QuestionDistilBERT(mod)\n",
1537
+ "model.load_state_dict(torch.load(\"distilbert_qa.model\"))"
1538
+ ]
1539
+ },
1540
+ {
1541
+ "cell_type": "code",
1542
+ "execution_count": 35,
1543
+ "id": "730a86c1",
1544
+ "metadata": {},
1545
+ "outputs": [
1546
+ {
1547
+ "name": "stderr",
1548
+ "output_type": "stream",
1549
+ "text": [
1550
+ "100%|鈻堚枅鈻堚枅鈻堚枅鈻堚枅鈻堚枅| 2500/2500 [02:57<00:00, 14.09it/s]"
1551
+ ]
1552
+ },
1553
+ {
1554
+ "name": "stdout",
1555
+ "output_type": "stream",
1556
+ "text": [
1557
+ "Mean EM: 0.0479\n",
1558
+ "Mean F-1: 0.08989175857485086\n"
1559
+ ]
1560
+ },
1561
+ {
1562
+ "name": "stderr",
1563
+ "output_type": "stream",
1564
+ "text": [
1565
+ "\n"
1566
+ ]
1567
+ }
1568
+ ],
1569
+ "source": [
1570
+ "eval_test_set(model, tokenizer, test_loader, device)"
1571
+ ]
1572
+ },
1573
+ {
1574
+ "cell_type": "markdown",
1575
+ "id": "bd1c7076",
1576
+ "metadata": {},
1577
+ "source": [
1578
+ "## Reuse Layer\n",
1579
+ "This was inspired by how well the original model with just one classification head worked. I felt like the main problem with the previous model was the lack of structure which was already in the layers, combined with the massive amount of resources needed for a Transformer.\n",
1580
+ "\n",
1581
+ "Hence, I tried cloning the last (and then last two) layers of the DistilBERT model, putting a classifier on top and using this as the head. The base DistilBERT model is completely frozen. This worked extremely well, while we only fine-tune about 21% of the parameters (14 Mio as opposed to 66 Mio!) we did before. Below you can see the results.\n",
1582
+ "\n",
1583
+ "### Last DistilBERT layer\n",
1584
+ "\n",
1585
+ "Dropout 0.1 and RMSprop 1e-4:\n",
1586
+ "* Mean EM: 0.3888\n",
1587
+ "* Mean F-1: 0.5122932744694068\n",
1588
+ "\n",
1589
+ "Dropout 0.25: very early stagnating\n",
1590
+ "* Mean EM: 0.3552\n",
1591
+ "* Mean F-1: 0.4711235721312687\n",
1592
+ "\n",
1593
+ "Dropout 0.15: seems to work well - training and test error stagnate around 1.7 and 1.8 but good generalisation (need to add more layers)\n",
1594
+ "* Mean EM: 0.4119\n",
1595
+ "* Mean F-1: 0.5296387232893214\n",
1596
+ "\n",
1597
+ "### Last DitilBERT layer + more Dense layers\n",
1598
+ "Dropout 0.15 + 4 dense layers((786-512)-(512-256)-(256-128)-(128-2)) & ReLU: doesn't work too well - stagnates at around 2.4\n",
1599
+ "\n",
1600
+ "### Last two DistilBERT layers\n",
1601
+ "Dropout 0.1 but last 2 DistilBERT layers: works very well, but early overfitting - maybe use more data\n",
1602
+ "* Mean EM: 0.458\n",
1603
+ "* Mean F-1: 0.6003368353673634\n",
1604
+ "\n",
1605
+ "Dropout 0.1 - last 2 distilbert layers: all data\n",
1606
+ "* Mean EM: 0.484\n",
1607
+ "* Mean F-1: 0.6344960035215299\n",
1608
+ "\n",
1609
+ "Dropout 0.15 - **BEST**\n",
1610
+ "* Mean EM: 0.5178\n",
1611
+ "* Mean F-1: 0.6671140689626448\n",
1612
+ "\n",
1613
+ "Dropout 0.2 - doesn't work too well\n",
1614
+ "* Mean EM: 0.4353\n",
1615
+ "* Mean F-1: 0.5776847879304647\n"
1616
+ ]
1617
+ },
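+ {
+ "cell_type": "markdown",
+ "id": "reuse-head-sketch",
+ "metadata": {},
+ "source": [
+ "Below is a minimal sketch of the reuse idea (illustrative only; the actual implementation is `ReuseQuestionDistilBERT` in `qa_model.py`, and the name `ReuseHeadSketch` is hypothetical): the pre-trained body is frozen, its last blocks are deep-copied as a trainable head, and a single linear layer maps every token to a start and an end logit.\n",
+ "\n",
+ "```python\n",
+ "import copy\n",
+ "import torch.nn as nn\n",
+ "\n",
+ "class ReuseHeadSketch(nn.Module):\n",
+ "    def __init__(self, distilbert, n_reused=2, dropout=0.15):\n",
+ "        super().__init__()\n",
+ "        # clone the last n transformer blocks first, so the clones stay trainable\n",
+ "        self.te = nn.ModuleList(copy.deepcopy(distilbert.transformer.layer[-n_reused:]))\n",
+ "        self.distilbert = distilbert\n",
+ "        for p in self.distilbert.parameters():\n",
+ "            p.requires_grad = False  # freeze the whole pre-trained body\n",
+ "        self.dropout = nn.Dropout(dropout)\n",
+ "        self.classifier = nn.Linear(768, 2)  # start and end logit per token\n",
+ "\n",
+ "    def forward(self, input_ids, attention_mask):\n",
+ "        h = self.distilbert(input_ids, attention_mask=attention_mask).last_hidden_state\n",
+ "        for block in self.te:\n",
+ "            h = block(h, attn_mask=attention_mask)[-1]  # HF TransformerBlock returns a tuple\n",
+ "        h = self.dropout(h)\n",
+ "        start_logits, end_logits = self.classifier(h).split(1, dim=-1)\n",
+ "        return start_logits.squeeze(-1), end_logits.squeeze(-1)\n",
+ "```\n",
+ "\n",
+ "Only the cloned blocks and the classifier receive gradients (the real class also computes the usual cross-entropy loss over the start/end positions), which is where the roughly 14 million trainable parameters reported by `count_parameters` below come from."
+ ]
+ },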
1618
+ {
1619
+ "cell_type": "code",
1620
+ "execution_count": 69,
1621
+ "id": "654e09e8",
1622
+ "metadata": {},
1623
+ "outputs": [],
1624
+ "source": [
1625
+ "dataset = Dataset(squad_paths = squad_paths, natural_question_paths=None, hotpotqa_paths=hotpotqa_paths, tokenizer=tokenizer)\n",
1626
+ "loader = torch.utils.data.DataLoader(dataset, batch_size=8)\n",
1627
+ "\n",
1628
+ "test_dataset = Dataset(squad_paths = [str(x) for x in Path('data/test_squad/').glob('**/*.txt')], \n",
1629
+ " natural_question_paths=None, \n",
1630
+ " hotpotqa_paths = None, tokenizer=tokenizer)\n",
1631
+ "test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=4)"
1632
+ ]
1633
+ },
1634
+ {
1635
+ "cell_type": "code",
1636
+ "execution_count": 70,
1637
+ "id": "707c0cb5",
1638
+ "metadata": {},
1639
+ "outputs": [
1640
+ {
1641
+ "data": {
1642
+ "text/plain": [
1643
+ "ReuseQuestionDistilBERT(\n",
1644
+ " (te): ModuleList(\n",
1645
+ " (0): TransformerBlock(\n",
1646
+ " (attention): MultiHeadSelfAttention(\n",
1647
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
1648
+ " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
1649
+ " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
1650
+ " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
1651
+ " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
1652
+ " )\n",
1653
+ " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
1654
+ " (ffn): FFN(\n",
1655
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
1656
+ " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
1657
+ " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
1658
+ " (activation): GELUActivation()\n",
1659
+ " )\n",
1660
+ " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
1661
+ " )\n",
1662
+ " (1): TransformerBlock(\n",
1663
+ " (attention): MultiHeadSelfAttention(\n",
1664
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
1665
+ " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
1666
+ " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
1667
+ " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
1668
+ " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
1669
+ " )\n",
1670
+ " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
1671
+ " (ffn): FFN(\n",
1672
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
1673
+ " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
1674
+ " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
1675
+ " (activation): GELUActivation()\n",
1676
+ " )\n",
1677
+ " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
1678
+ " )\n",
1679
+ " )\n",
1680
+ " (distilbert): DistilBertModel(\n",
1681
+ " (embeddings): Embeddings(\n",
1682
+ " (word_embeddings): Embedding(30522, 768, padding_idx=0)\n",
1683
+ " (position_embeddings): Embedding(512, 768)\n",
1684
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
1685
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
1686
+ " )\n",
1687
+ " (transformer): Transformer(\n",
1688
+ " (layer): ModuleList(\n",
1689
+ " (0): TransformerBlock(\n",
1690
+ " (attention): MultiHeadSelfAttention(\n",
1691
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
1692
+ " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
1693
+ " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
1694
+ " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
1695
+ " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
1696
+ " )\n",
1697
+ " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
1698
+ " (ffn): FFN(\n",
1699
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
1700
+ " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
1701
+ " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
1702
+ " (activation): GELUActivation()\n",
1703
+ " )\n",
1704
+ " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
1705
+ " )\n",
1706
+ " (1): TransformerBlock(\n",
1707
+ " (attention): MultiHeadSelfAttention(\n",
1708
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
1709
+ " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
1710
+ " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
1711
+ " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
1712
+ " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
1713
+ " )\n",
1714
+ " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
1715
+ " (ffn): FFN(\n",
1716
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
1717
+ " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
1718
+ " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
1719
+ " (activation): GELUActivation()\n",
1720
+ " )\n",
1721
+ " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
1722
+ " )\n",
1723
+ " (2): TransformerBlock(\n",
1724
+ " (attention): MultiHeadSelfAttention(\n",
1725
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
1726
+ " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
1727
+ " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
1728
+ " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
1729
+ " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
1730
+ " )\n",
1731
+ " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
1732
+ " (ffn): FFN(\n",
1733
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
1734
+ " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
1735
+ " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
1736
+ " (activation): GELUActivation()\n",
1737
+ " )\n",
1738
+ " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
1739
+ " )\n",
1740
+ " (3): TransformerBlock(\n",
1741
+ " (attention): MultiHeadSelfAttention(\n",
1742
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
1743
+ " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
1744
+ " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
1745
+ " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
1746
+ " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
1747
+ " )\n",
1748
+ " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
1749
+ " (ffn): FFN(\n",
1750
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
1751
+ " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
1752
+ " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
1753
+ " (activation): GELUActivation()\n",
1754
+ " )\n",
1755
+ " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
1756
+ " )\n",
1757
+ " (4): TransformerBlock(\n",
1758
+ " (attention): MultiHeadSelfAttention(\n",
1759
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
1760
+ " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
1761
+ " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
1762
+ " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
1763
+ " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
1764
+ " )\n",
1765
+ " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
1766
+ " (ffn): FFN(\n",
1767
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
1768
+ " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
1769
+ " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
1770
+ " (activation): GELUActivation()\n",
1771
+ " )\n",
1772
+ " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
1773
+ " )\n",
1774
+ " (5): TransformerBlock(\n",
1775
+ " (attention): MultiHeadSelfAttention(\n",
1776
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
1777
+ " (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
1778
+ " (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
1779
+ " (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
1780
+ " (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
1781
+ " )\n",
1782
+ " (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
1783
+ " (ffn): FFN(\n",
1784
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
1785
+ " (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
1786
+ " (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
1787
+ " (activation): GELUActivation()\n",
1788
+ " )\n",
1789
+ " (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
1790
+ " )\n",
1791
+ " )\n",
1792
+ " )\n",
1793
+ " )\n",
1794
+ " (relu): ReLU()\n",
1795
+ " (dropout): Dropout(p=0.15, inplace=False)\n",
1796
+ " (classifier): Linear(in_features=768, out_features=2, bias=True)\n",
1797
+ ")"
1798
+ ]
1799
+ },
1800
+ "execution_count": 70,
1801
+ "metadata": {},
1802
+ "output_type": "execute_result"
1803
+ }
1804
+ ],
1805
+ "source": [
1806
+ "model = DistilBertForMaskedLM.from_pretrained(\"distilbert-base-uncased\")\n",
1807
+ "config = DistilBertConfig.from_pretrained(\"distilbert-base-uncased\")\n",
1808
+ "mod = model.distilbert\n",
1809
+ "\n",
1810
+ "device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')\n",
1811
+ "model = ReuseQuestionDistilBERT(mod)\n",
1812
+ "model.to(device)"
1813
+ ]
1814
+ },
1815
+ {
1816
+ "cell_type": "code",
1817
+ "execution_count": 71,
1818
+ "id": "d2c6bff5",
1819
+ "metadata": {},
1820
+ "outputs": [
1821
+ {
1822
+ "name": "stdout",
1823
+ "output_type": "stream",
1824
+ "text": [
1825
+ "+-------------------------------+------------+\n",
1826
+ "| Modules | Parameters |\n",
1827
+ "+-------------------------------+------------+\n",
1828
+ "| te.0.attention.q_lin.weight | 589824 |\n",
1829
+ "| te.0.attention.q_lin.bias | 768 |\n",
1830
+ "| te.0.attention.k_lin.weight | 589824 |\n",
1831
+ "| te.0.attention.k_lin.bias | 768 |\n",
1832
+ "| te.0.attention.v_lin.weight | 589824 |\n",
1833
+ "| te.0.attention.v_lin.bias | 768 |\n",
1834
+ "| te.0.attention.out_lin.weight | 589824 |\n",
1835
+ "| te.0.attention.out_lin.bias | 768 |\n",
1836
+ "| te.0.sa_layer_norm.weight | 768 |\n",
1837
+ "| te.0.sa_layer_norm.bias | 768 |\n",
1838
+ "| te.0.ffn.lin1.weight | 2359296 |\n",
1839
+ "| te.0.ffn.lin1.bias | 3072 |\n",
1840
+ "| te.0.ffn.lin2.weight | 2359296 |\n",
1841
+ "| te.0.ffn.lin2.bias | 768 |\n",
1842
+ "| te.0.output_layer_norm.weight | 768 |\n",
1843
+ "| te.0.output_layer_norm.bias | 768 |\n",
1844
+ "| te.1.attention.q_lin.weight | 589824 |\n",
1845
+ "| te.1.attention.q_lin.bias | 768 |\n",
1846
+ "| te.1.attention.k_lin.weight | 589824 |\n",
1847
+ "| te.1.attention.k_lin.bias | 768 |\n",
1848
+ "| te.1.attention.v_lin.weight | 589824 |\n",
1849
+ "| te.1.attention.v_lin.bias | 768 |\n",
1850
+ "| te.1.attention.out_lin.weight | 589824 |\n",
1851
+ "| te.1.attention.out_lin.bias | 768 |\n",
1852
+ "| te.1.sa_layer_norm.weight | 768 |\n",
1853
+ "| te.1.sa_layer_norm.bias | 768 |\n",
1854
+ "| te.1.ffn.lin1.weight | 2359296 |\n",
1855
+ "| te.1.ffn.lin1.bias | 3072 |\n",
1856
+ "| te.1.ffn.lin2.weight | 2359296 |\n",
1857
+ "| te.1.ffn.lin2.bias | 768 |\n",
1858
+ "| te.1.output_layer_norm.weight | 768 |\n",
1859
+ "| te.1.output_layer_norm.bias | 768 |\n",
1860
+ "| classifier.weight | 1536 |\n",
1861
+ "| classifier.bias | 2 |\n",
1862
+ "+-------------------------------+------------+\n",
1863
+ "Total Trainable Params: 14177282\n"
1864
+ ]
1865
+ },
1866
+ {
1867
+ "data": {
1868
+ "text/plain": [
1869
+ "14177282"
1870
+ ]
1871
+ },
1872
+ "execution_count": 71,
1873
+ "metadata": {},
1874
+ "output_type": "execute_result"
1875
+ }
1876
+ ],
1877
+ "source": [
1878
+ "count_parameters(model)"
1879
+ ]
1880
+ },
1881
+ {
1882
+ "cell_type": "markdown",
1883
+ "id": "c386c2eb",
1884
+ "metadata": {},
1885
+ "source": [
1886
+ "### Testing the Model"
1887
+ ]
1888
+ },
1889
+ {
1890
+ "cell_type": "code",
1891
+ "execution_count": 72,
1892
+ "id": "818deed3",
1893
+ "metadata": {},
1894
+ "outputs": [],
1895
+ "source": [
1896
+ "# get smaller dataset\n",
1897
+ "batch_size = 8\n",
1898
+ "test_ds = Dataset(squad_paths = squad_paths[:2], natural_question_paths=None, hotpotqa_paths=None, tokenizer=tokenizer)\n",
1899
+ "test_ds_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size)\n",
1900
+ "optim=torch.optim.Adam(model.parameters())"
1901
+ ]
1902
+ },
1903
+ {
1904
+ "cell_type": "code",
1905
+ "execution_count": 73,
1906
+ "id": "9da40760",
1907
+ "metadata": {},
1908
+ "outputs": [
1909
+ {
1910
+ "name": "stdout",
1911
+ "output_type": "stream",
1912
+ "text": [
1913
+ "Passed\n"
1914
+ ]
1915
+ }
1916
+ ],
1917
+ "source": [
1918
+ "test_model(model, optim, test_ds_loader, device)"
1919
+ ]
1920
+ },
1921
+ {
1922
+ "cell_type": "markdown",
1923
+ "id": "c3f80248",
1924
+ "metadata": {},
1925
+ "source": [
1926
+ "### Model Training"
1927
+ ]
1928
+ },
1929
+ {
1930
+ "cell_type": "code",
1931
+ "execution_count": 24,
1932
+ "id": "e1adabe6",
1933
+ "metadata": {},
1934
+ "outputs": [],
1935
+ "source": [
1936
+ "from torch.optim import AdamW, RMSprop\n",
1937
+ "\n",
1938
+ "model.train()\n",
1939
+ "optim = AdamW(model.parameters(), lr=1e-4)"
1940
+ ]
1941
+ },
1942
+ {
1943
+ "cell_type": "code",
1944
+ "execution_count": 25,
1945
+ "id": "efe1cbd5",
1946
+ "metadata": {},
1947
+ "outputs": [
1948
+ {
1949
+ "data": {
1950
+ "application/vnd.jupyter.widget-view+json": {
1951
+ "model_id": "8785757b04214102830ded36c1392c8d",
1952
+ "version_major": 2,
1953
+ "version_minor": 0
1954
+ },
1955
+ "text/plain": [
1956
+ " 0%| | 0/35000 [00:00<?, ?it/s]"
1957
+ ]
1958
+ },
1959
+ "metadata": {},
1960
+ "output_type": "display_data"
1961
+ },
1962
+ {
1963
+ "name": "stdout",
1964
+ "output_type": "stream",
1965
+ "text": [
1966
+ "Mean Training Error 2.6535016193100383\n"
1967
+ ]
1968
+ },
1969
+ {
1970
+ "data": {
1971
+ "application/vnd.jupyter.widget-view+json": {
1972
+ "model_id": "836f5365498642fa9ae891a86dca5892",
1973
+ "version_major": 2,
1974
+ "version_minor": 0
1975
+ },
1976
+ "text/plain": [
1977
+ " 0%| | 0/2500 [00:00<?, ?it/s]"
1978
+ ]
1979
+ },
1980
+ "metadata": {},
1981
+ "output_type": "display_data"
1982
+ },
1983
+ {
1984
+ "name": "stdout",
1985
+ "output_type": "stream",
1986
+ "text": [
1987
+ "Mean Test Error 2.384517493388057\n"
1988
+ ]
1989
+ },
1990
+ {
1991
+ "data": {
1992
+ "application/vnd.jupyter.widget-view+json": {
1993
+ "model_id": "981e1cef83a1477e920d1cdbffdfcde1",
1994
+ "version_major": 2,
1995
+ "version_minor": 0
1996
+ },
1997
+ "text/plain": [
1998
+ " 0%| | 0/35000 [00:00<?, ?it/s]"
1999
+ ]
2000
+ },
2001
+ "metadata": {},
2002
+ "output_type": "display_data"
2003
+ },
2004
+ {
2005
+ "name": "stdout",
2006
+ "output_type": "stream",
2007
+ "text": [
2008
+ "Mean Training Error 2.172889394424643\n"
2009
+ ]
2010
+ },
2011
+ {
2012
+ "data": {
2013
+ "application/vnd.jupyter.widget-view+json": {
2014
+ "model_id": "20a785e7fefb43239f1120992d2c3416",
2015
+ "version_major": 2,
2016
+ "version_minor": 0
2017
+ },
2018
+ "text/plain": [
2019
+ " 0%| | 0/2500 [00:00<?, ?it/s]"
2020
+ ]
2021
+ },
2022
+ "metadata": {},
2023
+ "output_type": "display_data"
2024
+ },
2025
+ {
2026
+ "name": "stdout",
2027
+ "output_type": "stream",
2028
+ "text": [
2029
+ "Mean Test Error 2.013008696398139\n"
2030
+ ]
2031
+ },
2032
+ {
2033
+ "data": {
2034
+ "application/vnd.jupyter.widget-view+json": {
2035
+ "model_id": "47831e65b1ed4be78e8e7cb24068b0c3",
2036
+ "version_major": 2,
2037
+ "version_minor": 0
2038
+ },
2039
+ "text/plain": [
2040
+ " 0%| | 0/35000 [00:00<?, ?it/s]"
2041
+ ]
2042
+ },
2043
+ "metadata": {},
2044
+ "output_type": "display_data"
2045
+ },
2046
+ {
2047
+ "name": "stdout",
2048
+ "output_type": "stream",
2049
+ "text": [
2050
+ "Mean Training Error 1.9743544759827\n"
2051
+ ]
2052
+ },
2053
+ {
2054
+ "data": {
2055
+ "application/vnd.jupyter.widget-view+json": {
2056
+ "model_id": "15904a3f930249fb944ea87184676e14",
2057
+ "version_major": 2,
2058
+ "version_minor": 0
2059
+ },
2060
+ "text/plain": [
2061
+ " 0%| | 0/2500 [00:00<?, ?it/s]"
2062
+ ]
2063
+ },
2064
+ "metadata": {},
2065
+ "output_type": "display_data"
2066
+ },
2067
+ {
2068
+ "name": "stdout",
2069
+ "output_type": "stream",
2070
+ "text": [
2071
+ "Mean Test Error 1.8922049684919418\n"
2072
+ ]
2073
+ },
2074
+ {
2075
+ "data": {
2076
+ "application/vnd.jupyter.widget-view+json": {
2077
+ "model_id": "108bdbf644d94d78910195992b9e2652",
2078
+ "version_major": 2,
2079
+ "version_minor": 0
2080
+ },
2081
+ "text/plain": [
2082
+ " 0%| | 0/35000 [00:00<?, ?it/s]"
2083
+ ]
2084
+ },
2085
+ "metadata": {},
2086
+ "output_type": "display_data"
2087
+ },
2088
+ {
2089
+ "name": "stdout",
2090
+ "output_type": "stream",
2091
+ "text": [
2092
+ "Mean Training Error 1.857202093189742\n"
2093
+ ]
2094
+ },
2095
+ {
2096
+ "data": {
2097
+ "application/vnd.jupyter.widget-view+json": {
2098
+ "model_id": "d6a75a6ab40d4a2599b7511bfc60bf83",
2099
+ "version_major": 2,
2100
+ "version_minor": 0
2101
+ },
2102
+ "text/plain": [
2103
+ " 0%| | 0/2500 [00:00<?, ?it/s]"
2104
+ ]
2105
+ },
2106
+ "metadata": {},
2107
+ "output_type": "display_data"
2108
+ },
2109
+ {
2110
+ "name": "stdout",
2111
+ "output_type": "stream",
2112
+ "text": [
2113
+ "Mean Test Error 1.793771461571753\n"
2114
+ ]
2115
+ },
2116
+ {
2117
+ "data": {
2118
+ "application/vnd.jupyter.widget-view+json": {
2119
+ "model_id": "d3468a6ba72a4f42b0e7cc77ee0a0011",
2120
+ "version_major": 2,
2121
+ "version_minor": 0
2122
+ },
2123
+ "text/plain": [
2124
+ " 0%| | 0/35000 [00:00<?, ?it/s]"
2125
+ ]
2126
+ },
2127
+ "metadata": {},
2128
+ "output_type": "display_data"
2129
+ },
2130
+ {
2131
+ "name": "stdout",
2132
+ "output_type": "stream",
2133
+ "text": [
2134
+ "Mean Training Error 1.7750537034896867\n"
2135
+ ]
2136
+ },
2137
+ {
2138
+ "data": {
2139
+ "application/vnd.jupyter.widget-view+json": {
2140
+ "model_id": "8aca0aa529d2452e8bd29fe7ada934f2",
2141
+ "version_major": 2,
2142
+ "version_minor": 0
2143
+ },
2144
+ "text/plain": [
2145
+ " 0%| | 0/2500 [00:00<?, ?it/s]"
2146
+ ]
2147
+ },
2148
+ "metadata": {},
2149
+ "output_type": "display_data"
2150
+ },
2151
+ {
2152
+ "name": "stdout",
2153
+ "output_type": "stream",
2154
+ "text": [
2155
+ "Mean Test Error 1.7466133671954274\n"
2156
+ ]
2157
+ },
2158
+ {
2159
+ "data": {
2160
+ "application/vnd.jupyter.widget-view+json": {
2161
+ "model_id": "e09abdfa63c841ce97f445ba9b3eeaa8",
2162
+ "version_major": 2,
2163
+ "version_minor": 0
2164
+ },
2165
+ "text/plain": [
2166
+ " 0%| | 0/35000 [00:00<?, ?it/s]"
2167
+ ]
2168
+ },
2169
+ "metadata": {},
2170
+ "output_type": "display_data"
2171
+ },
2172
+ {
2173
+ "name": "stdout",
2174
+ "output_type": "stream",
2175
+ "text": [
2176
+ "Mean Training Error 1.7097622096568346\n"
2177
+ ]
2178
+ },
2179
+ {
2180
+ "data": {
2181
+ "application/vnd.jupyter.widget-view+json": {
2182
+ "model_id": "0f49dd32d33e4f398be0942a59d735ce",
2183
+ "version_major": 2,
2184
+ "version_minor": 0
2185
+ },
2186
+ "text/plain": [
2187
+ " 0%| | 0/2500 [00:00<?, ?it/s]"
2188
+ ]
2189
+ },
2190
+ "metadata": {},
2191
+ "output_type": "display_data"
2192
+ },
2193
+ {
2194
+ "name": "stdout",
2195
+ "output_type": "stream",
2196
+ "text": [
2197
+ "Mean Test Error 1.7642206047609448\n"
2198
+ ]
2199
+ },
2200
+ {
2201
+ "data": {
2202
+ "application/vnd.jupyter.widget-view+json": {
2203
+ "model_id": "a493dd70ffb64cd19830e5dc98607979",
2204
+ "version_major": 2,
2205
+ "version_minor": 0
2206
+ },
2207
+ "text/plain": [
2208
+ " 0%| | 0/35000 [00:00<?, ?it/s]"
2209
+ ]
2210
+ },
2211
+ "metadata": {},
2212
+ "output_type": "display_data"
2213
+ },
2214
+ {
2215
+ "name": "stderr",
2216
+ "output_type": "stream",
2217
+ "text": [
2218
+ "\n",
2219
+ "KeyboardInterrupt\n",
2220
+ "\n"
2221
+ ]
2222
+ }
2223
+ ],
2224
+ "source": [
2225
+ "epochs = 16\n",
2226
+ "\n",
2227
+ "for epoch in range(epochs):\n",
2228
+ " loop = tqdm(loader, leave=True)\n",
2229
+ " model.train()\n",
2230
+ " mean_training_error = []\n",
2231
+ " for batch in loop:\n",
2232
+ " optim.zero_grad()\n",
2233
+ " \n",
2234
+ " input_ids = batch['input_ids'].to(device)\n",
2235
+ " attention_mask = batch['attention_mask'].to(device)\n",
2236
+ " start = batch['start_positions'].to(device)\n",
2237
+ " end = batch['end_positions'].to(device)\n",
2238
+ " \n",
2239
+ " outputs = model(input_ids, attention_mask=attention_mask, start_positions=start, end_positions=end)\n",
2240
+ " # print(torch.argmax(outputs['start_logits'],axis=1), torch.argmax(outputs['end_logits'], axis=1), start, end)\n",
2241
+ " loss = outputs['loss']\n",
2242
+ " loss.backward()\n",
2243
+ " # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)\n",
2244
+ " optim.step()\n",
2245
+ " mean_training_error.append(loss.item())\n",
2246
+ " loop.set_description(f'Epoch {epoch}')\n",
2247
+ " loop.set_postfix(loss=loss.item())\n",
2248
+ " print(\"Mean Training Error\", np.mean(mean_training_error))\n",
2249
+ " \n",
2250
+ " loop = tqdm(test_loader, leave=True)\n",
2251
+ " model.eval()\n",
2252
+ " mean_test_error = []\n",
2253
+ " for batch in loop:\n",
2254
+ " \n",
2255
+ " input_ids = batch['input_ids'].to(device)\n",
2256
+ " attention_mask = batch['attention_mask'].to(device)\n",
2257
+ " start = batch['start_positions'].to(device)\n",
2258
+ " end = batch['end_positions'].to(device)\n",
2259
+ " \n",
2260
+ " outputs = model(input_ids, attention_mask=attention_mask, start_positions=start, end_positions=end)\n",
2261
+ " # print(torch.argmax(outputs['start_logits'],axis=1), torch.argmax(outputs['end_logits'], axis=1), start, end)\n",
2262
+ " loss = outputs['loss']\n",
2263
+ " \n",
2264
+ " mean_test_error.append(loss.item())\n",
2265
+ " loop.set_description(f'Epoch {epoch} Testset')\n",
2266
+ " loop.set_postfix(loss=loss.item())\n",
2267
+ " print(\"Mean Test Error\", np.mean(mean_test_error))\n",
2268
+ " torch.save(model.state_dict(), \"distilbert_reuse_{}\".format(epoch))"
2269
+ ]
2270
+ },
2271
+ {
2272
+ "cell_type": "code",
2273
+ "execution_count": 48,
2274
+ "id": "fdf37d18",
2275
+ "metadata": {},
2276
+ "outputs": [],
2277
+ "source": [
2278
+ "torch.save(model.state_dict(), \"distilbert_reuse.model\")"
2279
+ ]
2280
+ },
2281
+ {
2282
+ "cell_type": "code",
2283
+ "execution_count": 49,
2284
+ "id": "d1cfded4",
2285
+ "metadata": {},
2286
+ "outputs": [],
2287
+ "source": [
2288
+ "m = ReuseQuestionDistilBERT(mod)\n",
2289
+ "m.load_state_dict(torch.load(\"distilbert_reuse.model\"))\n",
2290
+ "model = m"
2291
+ ]
2292
+ },
2293
+ {
2294
+ "cell_type": "code",
2295
+ "execution_count": 47,
2296
+ "id": "233bdc18",
2297
+ "metadata": {},
2298
+ "outputs": [
2299
+ {
2300
+ "name": "stderr",
2301
+ "output_type": "stream",
2302
+ "text": [
2303
+ "100%|鈻堚枅鈻堚枅鈻堚枅鈻堚枅鈻堚枅| 2500/2500 [02:51<00:00, 14.59it/s]"
2304
+ ]
2305
+ },
2306
+ {
2307
+ "name": "stdout",
2308
+ "output_type": "stream",
2309
+ "text": [
2310
+ "Mean EM: 0.5178\n",
2311
+ "Mean F-1: 0.6671140689626448\n"
2312
+ ]
2313
+ },
2314
+ {
2315
+ "name": "stderr",
2316
+ "output_type": "stream",
2317
+ "text": [
2318
+ "\n"
2319
+ ]
2320
+ }
2321
+ ],
2322
+ "source": [
2323
+ "eval_test_set(model, tokenizer, test_loader, device)"
2324
+ ]
2325
+ },
2326
+ {
2327
+ "cell_type": "code",
2328
+ "execution_count": null,
2329
+ "id": "0fb1ce9e",
2330
+ "metadata": {},
2331
+ "outputs": [],
2332
+ "source": []
2333
+ }
2334
+ ],
2335
+ "metadata": {
2336
+ "kernelspec": {
2337
+ "display_name": "Python 3.10.8 ('venv': venv)",
2338
+ "language": "python",
2339
+ "name": "python3"
2340
+ },
2341
+ "language_info": {
2342
+ "codemirror_mode": {
2343
+ "name": "ipython",
2344
+ "version": 3
2345
+ },
2346
+ "file_extension": ".py",
2347
+ "mimetype": "text/x-python",
2348
+ "name": "python",
2349
+ "nbconvert_exporter": "python",
2350
+ "pygments_lexer": "ipython3",
2351
+ "version": "3.10.8"
2352
+ },
2353
+ "toc": {
2354
+ "base_numbering": 1,
2355
+ "nav_menu": {},
2356
+ "number_sections": true,
2357
+ "sideBar": true,
2358
+ "skip_h1_title": false,
2359
+ "title_cell": "Table of Contents",
2360
+ "title_sidebar": "Contents",
2361
+ "toc_cell": false,
2362
+ "toc_position": {},
2363
+ "toc_section_display": true,
2364
+ "toc_window_display": false
2365
+ },
2366
+ "varInspector": {
2367
+ "cols": {
2368
+ "lenName": 16,
2369
+ "lenType": 16,
2370
+ "lenVar": 40
2371
+ },
2372
+ "kernels_config": {
2373
+ "python": {
2374
+ "delete_cmd_postfix": "",
2375
+ "delete_cmd_prefix": "del ",
2376
+ "library": "var_list.py",
2377
+ "varRefreshCmd": "print(var_dic_list())"
2378
+ },
2379
+ "r": {
2380
+ "delete_cmd_postfix": ") ",
2381
+ "delete_cmd_prefix": "rm(",
2382
+ "library": "var_list.r",
2383
+ "varRefreshCmd": "cat(var_dic_list()) "
2384
+ }
2385
+ },
2386
+ "types_to_exclude": [
2387
+ "module",
2388
+ "function",
2389
+ "builtin_function_or_method",
2390
+ "instance",
2391
+ "_Feature"
2392
+ ],
2393
+ "window_display": false
2394
+ },
2395
+ "vscode": {
2396
+ "interpreter": {
2397
+ "hash": "85bf9c14e9ba73b783ed1274d522bec79eb0b2b739090180d8ce17bb11aff4aa"
2398
+ }
2399
+ }
2400
+ },
2401
+ "nbformat": 4,
2402
+ "nbformat_minor": 5
2403
+ }
requirements.txt ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ absl-py==1.3.0
2
+ aiohttp==3.8.3
3
+ aiosignal==1.2.0
4
+ altair==4.2.0
5
+ apache-beam>=2.41.0
6
+ argon2-cffi==21.3.0
7
+ argon2-cffi-bindings==21.2.0
8
+ asttokens==2.0.8
9
+ async-timeout==4.0.2
10
+ attrs==22.1.0
11
+ autopep8==1.7.0
12
+ backcall==0.2.0
13
+ beautifulsoup4==4.11.1
14
+ bleach==5.0.1
15
+ blinker==1.5
16
+ cachetools==5.2.0
17
+ certifi==2022.9.24
18
+ cffi==1.15.1
19
+ charset-normalizer==2.1.1
20
+ click==8.1.3
21
+ cloudpickle==2.2.0
22
+ commonmark==0.9.1
23
+ contourpy==1.0.5
24
+ crcmod==1.7
25
+ cycler==0.11.0
26
+ datasets==2.5.2
27
+ debugpy==1.6.3
28
+ decorator==5.1.1
29
+ defusedxml==0.7.1
30
+ dill==0.3.1.1
31
+ docopt==0.6.2
32
+ entrypoints==0.4
33
+ executing==1.1.0
34
+ fastavro==1.6.1
35
+ fastjsonschema==2.16.2
36
+ filelock==3.8.0
37
+ fonttools==4.37.4
38
+ frozenlist==1.3.1
39
+ fsspec==2022.8.2
40
+ gitdb==4.0.10
41
+ GitPython==3.1.29
42
+ google-auth==2.13.0
43
+ google-auth-oauthlib==0.4.6
44
+ grpcio==1.49.1
45
+ hdfs==2.7.0
46
+ httplib2==0.20.4
47
+ huggingface-hub==0.10.0
48
+ idna==3.4
49
+ importlib-metadata==5.1.0
50
+ ipykernel==6.16.0
51
+ ipython==8.5.0
52
+ ipython-genutils==0.2.0
53
+ ipywidgets==8.0.2
54
+ jedi==0.18.1
55
+ Jinja2==3.1.2
56
+ joblib==1.2.0
57
+ jsonschema==4.16.0
58
+ jupyter==1.0.0
59
+ jupyter-console==6.4.4
60
+ jupyter-contrib-core==0.4.0
61
+ jupyter-contrib-nbextensions==0.5.1
62
+ jupyter-highlight-selected-word==0.2.0
63
+ jupyter-latex-envs==1.4.6
64
+ jupyter-nbextensions-configurator==0.5.0
65
+ jupyter_client==7.3.5
66
+ jupyter_core==4.11.2
67
+ jupyterlab-pygments==0.2.2
68
+ jupyterlab-widgets==3.0.3
69
+ kiwisolver==1.4.4
70
+ lesscpy==0.15.1
71
+ lxml==4.9.1
72
+ Markdown==3.4.1
73
+ MarkupSafe==2.1.1
74
+ matplotlib==3.6.1
75
+ matplotlib-inline==0.1.6
76
+ mistune==2.0.4
77
+ multidict==6.0.2
78
+ multiprocess==0.70.9
79
+ mwparserfromhell==0.6.4
80
+ nbclient==0.7.0
81
+ nbconvert==7.2.1
82
+ nbformat==5.6.1
83
+ nest-asyncio==1.5.6
84
+ notebook==6.4.12
85
+ numpy==1.22.4
86
+ oauthlib==3.2.2
87
+ orjson==3.9.7
88
+ packaging==21.3
89
+ pandas==1.5.0
90
+ pandocfilters==1.5.0
91
+ parso==0.8.3
92
+ pexpect==4.8.0
93
+ pickleshare==0.7.5
94
+ Pillow==9.2.0
95
+ ply==3.11
96
+ prettytable==3.4.1
97
+ prometheus-client==0.14.1
98
+ prompt-toolkit==3.0.31
99
+ proto-plus==1.22.1
100
+ protobuf==3.19.6
101
+ psutil==5.9.2
102
+ ptyprocess==0.7.0
103
+ pure-eval==0.2.2
104
+ pyarrow==7.0.0
105
+ pyasn1==0.4.8
106
+ pyasn1-modules==0.2.8
107
+ pycodestyle==2.9.1
108
+ pycparser==2.21
109
+ pydeck==0.8.0
110
+ pydot==1.4.2
111
+ Pygments==2.13.0
112
+ pymongo==3.12.3
113
+ Pympler==1.0.1
114
+ pyparsing==3.0.9
115
+ pyrsistent==0.18.1
116
+ python-dateutil==2.8.2
117
+ pytz==2022.4
118
+ pytz-deprecation-shim==0.1.0.post0
119
+ PyYAML==6.0
120
+ pyzmq==24.0.1
121
+ qtconsole==5.3.2
122
+ QtPy==2.2.1
123
+ regex==2022.9.13
124
+ requests==2.28.1
125
+ requests-oauthlib==1.3.1
126
+ responses==0.18.0
127
+ rich==12.6.0
128
+ rsa==4.9
129
+ scikit-learn==1.1.2
130
+ scipy==1.9.1
131
+ semver==2.13.0
132
+ Send2Trash==1.8.0
133
+ six==1.16.0
134
+ smmap==5.0.0
135
+ soupsieve==2.3.2.post1
136
+ stack-data==0.5.1
137
+ streamlit==1.15.2
138
+ tensorboard==2.10.1
139
+ tensorboard-data-server==0.6.1
140
+ tensorboard-plugin-wit==1.8.1
141
+ terminado==0.16.0
142
+ threadpoolctl==3.1.0
143
+ tinycss2==1.1.1
144
+ tokenizers==0.12.1
145
+ toml==0.10.2
146
+ toolz==0.12.0
147
+ torch==1.12.1
148
+ torchaudio==0.12.1
149
+ torchsummary==1.5.1
150
+ torchtest==0.5
151
+ torchvision==0.13.1
152
+ tornado==6.2
153
+ tqdm==4.64.1
154
+ traitlets==5.4.0
155
+ transformers==4.22.2
156
+ typing_extensions==4.4.0
157
+ tzdata==2022.7
158
+ tzlocal==4.2
159
+ urllib3==1.26.12
160
+ validators==0.20.0
161
+ watchdog==2.2.0
162
+ wcwidth==0.2.5
163
+ webencodings==0.5.1
164
+ Werkzeug==2.2.2
165
+ widgetsnbextension==4.0.3
166
+ xxhash==3.0.0
167
+ yarl==1.8.1
168
+ zipp==3.11.0
util.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import numpy as np
3
+ from prettytable import PrettyTable
4
+ from tqdm import tqdm
5
+ import torch
6
+
7
+
8
+ def normalize_text(s):
9
+ """
10
+ Removes articles and punctuation and standardizes whitespace; these are all typical text-processing steps.
11
+ Copied from: https://qa.fastforwardlabs.com/no%20answer/null%20threshold/bert/distilbert/exact%20match/f1/robust%20predictions/2020/06/09/Evaluating_BERT_on_SQuAD.html#Metrics-for-QA
12
+ :param s: string to clean
13
+ :return: cleaned string
14
+ """
15
+ import string, re
16
+
17
+ def remove_articles(text):
18
+ regex = re.compile(r"\b(a|an|the)\b", re.UNICODE)
19
+ return re.sub(regex, " ", text)
20
+
21
+ def white_space_fix(text):
22
+ return " ".join(text.split())
23
+
24
+ def remove_punc(text):
25
+ exclude = set(string.punctuation)
26
+ return "".join(ch for ch in text if ch not in exclude)
27
+
28
+ def lower(text):
29
+ return text.lower()
30
+
31
+ return white_space_fix(remove_articles(remove_punc(lower(s))))
32
+
33
+
34
+ def compute_exact_match(prediction, truth):
35
+ """
36
+ Returns 1 if the prediction is an exact match with the ground truth, else 0
37
+ Retrieved from: https://qa.fastforwardlabs.com/no%20answer/null%20threshold/bert/distilbert/exact%20match/f1/robust%20predictions/2020/06/09/Evaluating_BERT_on_SQuAD.html#Metrics-for-QA
38
+ :param prediction: predicted answer
39
+ :param truth: ground truth
40
+ :return: 1 if exact match, else 0
41
+ """
42
+ return int(normalize_text(prediction) == normalize_text(truth))
43
+
44
+
45
+ def compute_f1(prediction, truth):
46
+ """
47
+ Computes the F-1 score of a prediction, based on the tokens
48
+ Retrieved from: https://qa.fastforwardlabs.com/no%20answer/null%20threshold/bert/distilbert/exact%20match/f1/robust%20predictions/2020/06/09/Evaluating_BERT_on_SQuAD.html#Metrics-for-QA
49
+ :param prediction: predicted answer
50
+ :param truth: ground truth
51
+ :return: the f-1 score of the prediction
52
+ """
53
+ pred_tokens = normalize_text(prediction).split()
54
+ truth_tokens = normalize_text(truth).split()
55
+
56
+ # if either the prediction or the truth is no-answer then f1 = 1 if they agree, 0 otherwise
57
+ if len(pred_tokens) == 0 or len(truth_tokens) == 0:
58
+ return int(pred_tokens == truth_tokens)
59
+
60
+ # get tokens that are in the prediction and gt
61
+ common_tokens = set(pred_tokens) & set(truth_tokens)
62
+
63
+ # if there are no common tokens then f1 = 0
64
+ if len(common_tokens) == 0:
65
+ return 0
66
+
67
+ # calculate precision and recall
68
+ prec = len(common_tokens) / len(pred_tokens)
69
+ rec = len(common_tokens) / len(truth_tokens)
70
+
71
+ return 2 * (prec * rec) / (prec + rec)
72
+
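+ # Example usage (illustrative):
+ #   compute_exact_match('in the park', 'the park')  ->  0
+ #   compute_f1('in the park', 'the park')           ->  0.667 (= 2/3)
+ # normalize_text strips the article 'the', so the token sets are
+ # {'in', 'park'} vs {'park'}: precision 0.5, recall 1.0, F1 = 2/3.
+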
73
+ def eval_test_set(model, tokenizer, test_loader, device):
74
+ """
75
+ Calculates the mean EM and mean F-1 score on the test set
76
+ :param model: pytorch model
77
+ :param tokenizer: tokenizer used to encode the samples
78
+ :param test_loader: dataloader object with test data
79
+ :param device: device the model is on
80
+ """
81
+ mean_em = []
82
+ mean_f1 = []
83
+ model.to(device)
84
+ model.eval()
85
+ for batch in tqdm(test_loader):
86
+ # get test data and transfer to device
87
+ input_ids = batch['input_ids'].to(device)
88
+ attention_mask = batch['attention_mask'].to(device)
89
+ start = batch['start_positions'].to(device)
90
+ end = batch['end_positions'].to(device)
91
+
92
+ # predict
93
+ outputs = model(input_ids, attention_mask=attention_mask, start_positions=start, end_positions=end)
94
+
95
+ # iterate over samples, calculate EM and F-1 for all
96
+ for input_i, s, e, trues, truee in zip(input_ids, outputs['start_logits'], outputs['end_logits'], start, end):
97
+ # get predicted start and end token positions (argmax over the logits)
98
+ start_logits = torch.argmax(s)
99
+ end_logits = torch.argmax(e)
100
+
101
+ # get predicted answer as string
102
+ ans_tokens = input_i[start_logits: end_logits + 1]
103
+ answer_tokens = tokenizer.convert_ids_to_tokens(ans_tokens, skip_special_tokens=True)
104
+ predicted = tokenizer.convert_tokens_to_string(answer_tokens)
105
+
106
+ # get ground truth as string
107
+ ans_tokens = input_i[trues: truee + 1]
108
+ answer_tokens = tokenizer.convert_ids_to_tokens(ans_tokens, skip_special_tokens=True)
109
+ true = tokenizer.convert_tokens_to_string(answer_tokens)
110
+
111
+ # compute score
112
+ em_score = compute_exact_match(predicted, true)
113
+ f1_score = compute_f1(predicted, true)
114
+ mean_em.append(em_score)
115
+ mean_f1.append(f1_score)
116
+ print("Mean EM: ", np.mean(mean_em))
117
+ print("Mean F-1: ", np.mean(mean_f1))
118
+
119
+ def count_parameters(model):
120
+ """
121
+ This function prints statistics regarding the trainable parameters
122
+ :param model: pytorch model
123
+ :return: total number of trainable parameters
124
+ """
125
+ table = PrettyTable(["Modules", "Parameters"])
126
+ total_params = 0
127
+ for name, parameter in model.named_parameters():
128
+ if not parameter.requires_grad: continue
129
+ params = parameter.numel()
130
+ table.add_row([name, params])
131
+ total_params += params
132
+ print(table)
133
+ print(f"Total Trainable Params: {total_params}")
134
+ return total_params