Upload 8 files
- application.py +70 -0
- distilbert.ipynb +981 -0
- distilbert.py +175 -0
- load_data.ipynb +1209 -0
- qa_model.py +532 -0
- question_answering.ipynb +2403 -0
- requirements.txt +168 -0
- util.py +134 -0
application.py
ADDED
@@ -0,0 +1,70 @@
import streamlit as st
import numpy as np
import torch
from transformers import DistilBertTokenizer, DistilBertForMaskedLM

from qa_model import ReuseQuestionDistilBERT


@st.cache(allow_output_mutation=True)
def load_model():
    mod = DistilBertForMaskedLM.from_pretrained("distilbert-base-uncased").distilbert
    m = ReuseQuestionDistilBERT(mod)
    m.load_state_dict(torch.load("distilbert_reuse.model", map_location=torch.device('cpu')))
    model = m
    del mod
    del m
    tokenizer = DistilBertTokenizer.from_pretrained('qa_tokenizer')
    return model, tokenizer


def get_answer(question, text, tokenizer, model):
    question = [question.strip()]
    text = [text.strip()]

    inputs = tokenizer(
        question,
        text,
        max_length=512,
        truncation="only_second",
        padding="max_length",
    )
    input_ids = torch.tensor(inputs['input_ids'])
    outputs = model(input_ids, attention_mask=torch.tensor(inputs['attention_mask']),
                    start_positions=None, end_positions=None)

    start = torch.argmax(outputs['start_logits'])
    end = torch.argmax(outputs['end_logits'])

    ans_tokens = input_ids[0][start: end + 1]

    answer_tokens = tokenizer.convert_ids_to_tokens(ans_tokens, skip_special_tokens=True)
    predicted = tokenizer.convert_tokens_to_string(answer_tokens)
    return predicted


def main():
    st.set_page_config(page_title="Question Answering Tool", page_icon=":mag_right:")

    st.write("# Question Answering Tool \n"
             "This tool will help you find answers to your questions about the text you provide. \n"
             "Please enter your question and the text you want to search in the boxes below.")
    model, tokenizer = load_model()

    with st.form("qa_form"):
        # define a streamlit textarea
        text = st.text_area("Enter your text here", on_change=None)

        # define a streamlit input
        question = st.text_input("Enter your question here")

        if st.form_submit_button("Submit"):
            data_load_state = st.text('Let me think about that...')
            # call the function to get the answer
            answer = get_answer(question, text, tokenizer, model)
            # display the answer
            if answer == "":
                data_load_state.text("Sorry but I don't know the answer to that question")
            else:
                data_load_state.text(answer)


main()
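One note on the structure of this script: main() is called unconditionally at module level, so importing application.py (for example to reuse get_answer() in a test) starts the Streamlit page as a side effect. A common pattern, shown here only as a sketch and not part of the uploaded file, is to guard the call:

if __name__ == "__main__":
    main()

With that guard in place, load_model() and get_answer() can be imported from another script without launching the UI.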
distilbert.ipynb
ADDED
@@ -0,0 +1,981 @@
# DistilBERT Base Model
The following contains the code to create and train a DistilBERT model using the Huggingface library. It works quite well for a moderate amount of data, but the runtime grows drastically with the amount of data.

In the end I decided to use the pretrained model after all; still, creating the model myself was quite interesting!

from pathlib import Path
import torch
import time
from transformers import DistilBertTokenizerFast
import os
from transformers import DistilBertConfig
from transformers import DistilBertForMaskedLM
from tokenizers import BertWordPieceTokenizer
from tqdm.auto import tqdm
from torch.optim import AdamW
import torchtest
from transformers import pipeline

from distilbert import test_model
from distilbert import Dataset

import numpy as np

## Tokeniser
We need a way to convert the input strings into numerical tokens that we can feed to the neural network. Hence, we take a BertWordPieceTokenizer (it works for DistilBERT too) and create tokens from our words.

fit_new_tokenizer = True

if fit_new_tokenizer:
    paths = [str(x) for x in Path('data/original').glob('**/*.txt')]

    tokenizer = BertWordPieceTokenizer(
        clean_text=True,
        handle_chinese_chars=False,
        strip_accents=False,
        lowercase=True
    )
    print("Tokeniser created")

Output: Tokeniser created

# fit the tokenizer
if fit_new_tokenizer:
    tokenizer.train(files=paths[:10], vocab_size=30_000, min_frequency=2,
                    limit_alphabet=1000, wordpieces_prefix='##',
                    special_tokens=['[PAD]', '[UNK]', '[CLS]', '[SEP]', '[MASK]'])

if fit_new_tokenizer:
    os.mkdir('./tokeniser')
    tokenizer.save_model('tokeniser')
    print("Tokeniser saved")

Output: FileExistsError: [Errno 17] File exists: './tokeniser' (the directory already existed from an earlier run)

After creating a basic tokeniser, we use it to initialise a DistilBERT tokenizer, which we need for the model architecture later on. We save this tokeniser separately.

tokenizer = DistilBertTokenizerFast.from_pretrained('tokeniser', max_len=512)
tokenizer.save_pretrained("distilbert_tokenizer")

Output: ('distilbert_tokenizer/tokenizer_config.json', 'distilbert_tokenizer/special_tokens_map.json', 'distilbert_tokenizer/vocab.txt', 'distilbert_tokenizer/added_tokens.json', 'distilbert_tokenizer/tokenizer.json')

### Testing
We now test the created tokenizer on a simple example. It can be seen that a special token is added at the beginning and at the end ('[CLS]' and '[SEP]'), which is how the BERT model was defined.

When we translate the input back, we get the same text, except for the first and last token. We can also see that question marks and commas are encoded as separate tokens.

tokens = tokenizer('Hello, how are you?')
print(tokens)

Output: {'input_ids': [2, 10958, 16, 2175, 1993, 1965, 35, 3], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1]}

tokenizer.decode(tokens['input_ids'])

Output: '[CLS] hello, how are you? [SEP]'

for tok in tokens['input_ids']:
    print(tokenizer.decode(tok))

Output (one token per line): [CLS] / hello / , / how / are / you / ? / [SEP]

assert len(tokenizer.vocab) == 30_000

## Dataset
We now define a function to mask some of the tokens. In particular, we create a Dataset class that automates loading and tokenising the data for us. Lastly, we use a DataLoader to load the data into memory step by step.

The big problem with the limited resources we have is memory. In particular, I am loading the data sequentially, file by file, keeping track of how many samples have been read. Shuffling would not work here (it would also not make much sense for this dataset).

# create dataset and dataloader
dataset = Dataset(paths=[str(x) for x in Path('data/original').glob('**/*.txt')][50:70], tokenizer=tokenizer)
loader = torch.utils.data.DataLoader(dataset, batch_size=8)

test_dataset = Dataset(paths=[str(x) for x in Path('data/original').glob('**/*.txt')][10:12], tokenizer=tokenizer)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=4)

### Testing
The randomisation makes this a bit difficult to test. Altogether, we see that the input ids, masks and labels have the same shape. Also, since we mask 15% of the tokens, decoding a given sample shows that some tokens are now '[MASK]'.

i = iter(dataset)

for j in range(10):
    sample = next(i)

    input_ids = sample['input_ids']
    attention_masks = sample['attention_mask']
    labels = sample['labels']

    # check if the dimensions are right
    assert input_ids.shape[0] == 512
    assert attention_masks.shape[0] == 512
    assert labels.shape[0] == 512

    # if the input ids are not masked, the labels are the same as the input ids
    assert np.array_equal(input_ids[input_ids != 4].numpy(), labels[input_ids != 4].numpy())
    # input ids are zero if the attention masks are zero
    assert np.all(input_ids[attention_masks == 0].numpy() == 0)
    # check if the input contains masked tokens (not guaranteed 100%, but it will most likely apply)
    assert np.any(input_ids.numpy() == 4)
print("Passed")

Output: Passed

## Model
In the following section, we initialise and train a model.

config = DistilBertConfig(
    vocab_size=30000,
    max_position_embeddings=514
)

model = DistilBertForMaskedLM(config)

# if we have a GPU - train on gpu
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model.to(device)

Output: a CUDA initialisation warning (no usable GPU in this environment), followed by the DistilBertForMaskedLM module summary: 30000 x 768 word embeddings, 514 x 768 position embeddings, six TransformerBlocks (multi-head self-attention with 768-dimensional q/k/v/out projections, FFN with 3072 hidden units) and a vocabulary projector back to 30000 tokens.

### Testing the model
I stumbled across some Medium articles on how to test deep learning models beforehand:
* https://thenerdstation.medium.com/how-to-unit-test-machine-learning-code-57cf6fd81765: the package is however deprecated
* https://towardsdatascience.com/testing-your-pytorch-models-with-torcheck-cb689ecbc08c: released a package (torcheck)
* https://github.com/suriyadeepan/torchtest: the PyTorch version of the first package, which is still maintained

Essentially, testing a model is inherently difficult because we do not know the result in advance. Still, the following four conditions should be satisfied in every model (see the second reference above):
1. The parameters should change during training (if they are not frozen).
2. The parameters should not change if they are frozen.
3. The output should lie in a predefined range.
4. The parameters should never contain NaN. The same goes for the outputs.

I tried using the packages, but they do not trivially apply to models with multiple inputs (we have input ids and attention masks). The following is partly adapted from the torchtest package (https://github.com/suriyadeepan/torchtest/blob/master/torchtest/torchtest.py).

# get smaller dataset
test_ds = Dataset(paths=[str(x) for x in Path('data/original').glob('**/*.txt')][:2], tokenizer=tokenizer)
test_ds_loader = torch.utils.data.DataLoader(test_ds, batch_size=2)
optim = torch.optim.Adam(model.parameters())

from distilbert import test_model

test_model(model, optim, test_ds_loader, device)

Output: Passed

### Training the model
We use AdamW as the optimiser and train for 10 epochs.

Training on the whole dataset would take about 100 hours per epoch for me, so I was not able to do that.

model = DistilBertForMaskedLM(config)
# if we have a GPU - train on gpu
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model.to(device)

Output: the same DistilBertForMaskedLM module summary as above.

# we use AdamW as the optimiser
optim = AdamW(model.parameters(), lr=1e-4)

epochs = 10

for epoch in range(epochs):
    loop = tqdm(loader, leave=True)

    # set model to training mode
    model.train()
    losses = []

    # iterate over dataset
    for batch in loop:
        optim.zero_grad()

        # copy input to device
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        # predict
        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)

        # update weights
        loss = outputs.loss
        loss.backward()

        optim.step()

        # output current loss
        loop.set_description(f'Epoch {epoch}')
        loop.set_postfix(loss=loss.item())
        losses.append(loss.item())

        del input_ids
        del attention_mask
        del labels

    print("Mean Training Loss", np.mean(losses))
    losses = []
    loop = tqdm(test_loader, leave=True)

    # set model to evaluation mode
    model.eval()

    # iterate over dataset
    for batch in loop:
        # copy input to device
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)

        # predict
        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)

        # compute the validation loss (no weight update here)
        loss = outputs.loss

        # output current loss
        loop.set_description(f'Epoch {epoch}')
        loop.set_postfix(loss=loss.item())
        losses.append(loss.item())

        del input_ids
        del attention_mask
        del labels
    print("Mean Test Loss", np.mean(losses))

Output: progress bar (0/23750 steps).

# save the pretrained model
torch.save(model, "distilbert.model")

model = torch.load("distilbert.model")

### Testing
Huggingface provides a pipeline to quickly see which word the model would predict for a masked token.

fill = pipeline("fill-mask", model='distilbert', config=config, tokenizer='distilbert_tokenizer')

fill(f'It seems important to tackle the climate {fill.tokenizer.mask_token}.')

Output (top predictions with scores): change (0.197), crisis (0.129), issues (0.059), issue (0.047), here (0.028), e.g. "it seems important to tackle the climate change."

(Notebook run with a Python 3.10 kernel.)
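The notebook only reports the mean cross-entropy loss per epoch. A common way to make that number interpretable for a masked language model is to exponentiate it into a perplexity; the following is a minimal sketch with an assumed placeholder loss value, not a measurement from this run.

import math

# assumed placeholder value, not a result from the notebook above
mean_test_loss = 3.2

# perplexity of the masked-LM predictions = exp(mean cross-entropy loss)
print(math.exp(mean_test_loss))  # ~24.5, i.e. the model is roughly as uncertain as a uniform choice over ~25 tokens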
distilbert.py
ADDED
@@ -0,0 +1,175 @@
import torch

class Dataset(torch.utils.data.Dataset):
    """
    This class loads and preprocesses the given text data
    """
    def __init__(self, paths, tokenizer):
        """
        This function initialises the object. It takes the given paths and tokeniser.
        """
        # the last file might not have 10000 samples, which makes it difficult to get the total length of the ds
        self.paths = paths[:len(paths)-1]
        self.tokenizer = tokenizer
        self.data = self.read_file(self.paths[0])
        self.current_file = 1
        self.remaining = len(self.data)
        self.encodings = self.get_encodings(self.data)

    def __len__(self):
        """
        returns the length of the ds
        """
        return 10000*len(self.paths)

    def read_file(self, path):
        """
        reads a given file
        """
        with open(path, 'r', encoding='utf-8') as f:
            lines = f.read().split('\n')
        return lines

    def get_encodings(self, lines_all):
        """
        Creates encodings for a given text input
        """
        # tokenise all text
        batch = self.tokenizer(lines_all, max_length=512, padding='max_length', truncation=True)

        # Ground Truth
        labels = torch.tensor(batch['input_ids'])
        # Attention Masks
        mask = torch.tensor(batch['attention_mask'])

        # Input to be masked
        input_ids = labels.detach().clone()
        rand = torch.rand(input_ids.shape)

        # with a probability of 15%, mask a given word, leave out CLS, SEP and PAD
        mask_arr = (rand < .15) * (input_ids != 0) * (input_ids != 2) * (input_ids != 3)
        # assign token 4 (=MASK)
        input_ids[mask_arr] = 4

        return {'input_ids': input_ids, 'attention_mask': mask, 'labels': labels}

    def __getitem__(self, i):
        """
        returns item i
        Note: do not use shuffling for this dataset
        """
        # if we have looked at all items in the file - take next
        if self.remaining == 0:
            self.data = self.read_file(self.paths[self.current_file])
            self.current_file += 1
            self.remaining = len(self.data)
            self.encodings = self.get_encodings(self.data)

        # if we are at the end of the dataset, start over again
        if self.current_file == len(self.paths):
            self.current_file = 0

        self.remaining -= 1
        return {key: tensor[i % 10000] for key, tensor in self.encodings.items()}

def test_model(model, optim, test_ds_loader, device):
    """
    This function tests whether the non-frozen parameters of the model change, whether the frozen ones stay unchanged,
    and whether any parameters become NaN or Inf
    :param model: model to be tested
    :param optim: optimiser used for training
    :param test_ds_loader: dataset to perform the forward pass on
    :param device: current device
    :raises Exception: if any of the above conditions are not met
    """
    ## Check if non-frozen parameters changed and frozen ones did not

    # get initial parameters to check against
    params = [np for np in model.named_parameters() if np[1].requires_grad]
    initial_params = [(name, p.clone()) for (name, p) in params]

    params_frozen = [np for np in model.named_parameters() if not np[1].requires_grad]
    initial_params_frozen = [(name, p.clone()) for (name, p) in params_frozen]

    optim.zero_grad()

    # get data
    batch = next(iter(test_ds_loader))

    input_ids = batch['input_ids'].to(device)
    attention_mask = batch['attention_mask'].to(device)
    labels = batch['labels'].to(device)

    # forward pass and backpropagation
    outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
    loss = outputs.loss
    loss.backward()
    optim.step()

    # check if variables have changed
    for (_, p0), (name, p1) in zip(initial_params, params):
        # check different than initial
        try:
            assert not torch.equal(p0.to(device), p1.to(device))
        except AssertionError:
            raise Exception("{var_name} {msg}".format(var_name=name, msg='did not change!'))
        # check not NaN
        try:
            assert not torch.isnan(p1).byte().any()
        except AssertionError:
            raise Exception("{var_name} {msg}".format(var_name=name, msg='is NaN!'))
        # check finite
        try:
            assert torch.isfinite(p1).byte().all()
        except AssertionError:
            raise Exception("{var_name} {msg}".format(var_name=name, msg='is Inf!'))

    # check that frozen weights have not changed
    for (_, p0), (name, p1) in zip(initial_params_frozen, params_frozen):
        # should be the same
        try:
            assert torch.equal(p0.to(device), p1.to(device))
        except AssertionError:
            raise Exception("{var_name} {msg}".format(var_name=name, msg='changed!'))
        # check not NaN
        try:
            assert not torch.isnan(p1).byte().any()
        except AssertionError:
            raise Exception("{var_name} {msg}".format(var_name=name, msg='is NaN!'))
        # check finite numbers
        try:
            assert torch.isfinite(p1).byte().all()
        except AssertionError:
            raise Exception("{var_name} {msg}".format(var_name=name, msg='is Inf!'))
    print("Passed")
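A minimal usage sketch for the two pieces in this file, assuming the data/original text files and the distilbert_tokenizer directory from the notebooks exist; shuffling is left off because __getitem__ walks the files sequentially.

from pathlib import Path
import torch
from transformers import DistilBertTokenizerFast, DistilBertConfig, DistilBertForMaskedLM
from distilbert import Dataset, test_model

tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert_tokenizer', max_len=512)
paths = [str(x) for x in Path('data/original').glob('**/*.txt')]

ds = Dataset(paths=paths[:3], tokenizer=tokenizer)
# keep shuffle=False: the dataset reads its files sequentially
loader = torch.utils.data.DataLoader(ds, batch_size=2, shuffle=False)

model = DistilBertForMaskedLM(DistilBertConfig(vocab_size=30000, max_position_embeddings=514))
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model.to(device)

# one optimisation step on one batch; raises if parameters fail to update or become NaN/Inf
test_model(model, torch.optim.Adam(model.parameters()), loader, device)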
load_data.ipynb
ADDED
@@ -0,0 +1,1209 @@
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"id": "12d87b30",
|
6 |
+
"metadata": {},
|
7 |
+
"source": [
|
8 |
+
"# Load Data\n",
|
9 |
+
"This notebook loads and preproceses all necessary data, namely the following.\n",
|
10 |
+
"* OpenWebTextCorpus: for base DistilBERT model\n",
|
11 |
+
"* SQuAD datasrt: for Q&A\n",
|
12 |
+
"* Natural Questions (needs to be downloaded externally but is preprocessed here): for Q&A\n",
|
13 |
+
"* HotPotQA: for Q&A"
|
14 |
+
]
|
15 |
+
},
|
16 |
+
{
|
17 |
+
"cell_type": "code",
|
18 |
+
"execution_count": 4,
|
19 |
+
"id": "7c82d7fa",
|
20 |
+
"metadata": {},
|
21 |
+
"outputs": [],
|
22 |
+
"source": [
|
23 |
+
"from tqdm.auto import tqdm\n",
|
24 |
+
"from datasets import load_dataset\n",
|
25 |
+
"import os\n",
|
26 |
+
"import pandas as pd\n",
|
27 |
+
"import random"
|
28 |
+
]
|
29 |
+
},
|
30 |
+
{
|
31 |
+
"cell_type": "markdown",
|
32 |
+
"id": "1737f219",
|
33 |
+
"metadata": {},
|
34 |
+
"source": [
|
35 |
+
"## Distilbert Data\n",
|
36 |
+
"In the following, we download the english openwebtext dataset from huggingface (https://huggingface.co/datasets/openwebtext). The dataset is provided by Aaron Gokaslan and Vanya Cohen from Brown University (https://skylion007.github.io/OpenWebTextCorpus/).\n",
|
37 |
+
"\n",
|
38 |
+
"We first load the data, investigate the structure and write the dataset into files of each 10 000 texts."
|
39 |
+
]
|
40 |
+
},
|
41 |
+
{
|
42 |
+
"cell_type": "code",
|
43 |
+
"execution_count": null,
|
44 |
+
"id": "cce7623c",
|
45 |
+
"metadata": {},
|
46 |
+
"outputs": [],
|
47 |
+
"source": [
|
48 |
+
"ds = load_dataset(\"openwebtext\")"
|
49 |
+
]
|
50 |
+
},
|
51 |
+
{
|
52 |
+
"cell_type": "code",
|
53 |
+
"execution_count": 4,
|
54 |
+
"id": "678a5e86",
|
55 |
+
"metadata": {},
|
56 |
+
"outputs": [
|
57 |
+
{
|
58 |
+
"data": {
|
59 |
+
"text/plain": [
|
60 |
+
"DatasetDict({\n",
|
61 |
+
" train: Dataset({\n",
|
62 |
+
" features: ['text'],\n",
|
63 |
+
" num_rows: 8013769\n",
|
64 |
+
" })\n",
|
65 |
+
"})"
|
66 |
+
]
|
67 |
+
},
|
68 |
+
"execution_count": 4,
|
69 |
+
"metadata": {},
|
70 |
+
"output_type": "execute_result"
|
71 |
+
}
|
72 |
+
],
|
73 |
+
"source": [
|
74 |
+
"# we have a text-only training dataset with 8 million entries\n",
|
75 |
+
"ds"
|
76 |
+
]
|
77 |
+
},
|
78 |
+
{
|
79 |
+
"cell_type": "code",
|
80 |
+
"execution_count": 5,
|
81 |
+
"id": "b141bce7",
|
82 |
+
"metadata": {},
|
83 |
+
"outputs": [],
|
84 |
+
"source": [
|
85 |
+
"# create necessary folders\n",
|
86 |
+
"os.mkdir('data')\n",
|
87 |
+
"os.mkdir('data/original')"
|
88 |
+
]
|
89 |
+
},
|
90 |
+
{
|
91 |
+
"cell_type": "code",
|
92 |
+
"execution_count": null,
|
93 |
+
"id": "ca94f995",
|
94 |
+
"metadata": {},
|
95 |
+
"outputs": [],
|
96 |
+
"source": [
|
97 |
+
"# save text in chunks of 10000 samples\n",
|
98 |
+
"text = []\n",
|
99 |
+
"i = 0\n",
|
100 |
+
"\n",
|
101 |
+
"for sample in tqdm(ds['train']):\n",
|
102 |
+
" # replace all newlines\n",
|
103 |
+
" sample = sample['text'].replace('\\n','')\n",
|
104 |
+
" \n",
|
105 |
+
" # append cleaned sample to all texts\n",
|
106 |
+
" text.append(sample)\n",
|
107 |
+
" \n",
|
108 |
+
" # if we processed 10000 samples, write them to a file and start over\n",
|
109 |
+
" if len(text) == 10000:\n",
|
110 |
+
" with open(f\"data/original/text_{i}.txt\", 'w', encoding='utf-8') as f:\n",
|
111 |
+
" f.write('\\n'.join(text))\n",
|
112 |
+
" text = []\n",
|
113 |
+
" i += 1 \n",
|
114 |
+
"\n",
|
115 |
+
"# write remaining samples to a file\n",
|
116 |
+
"with open(f\"data/original/text_{i}.txt\", 'w', encoding='utf-8') as f:\n",
|
117 |
+
" f.write('\\n'.join(text))"
|
118 |
+
]
|
119 |
+
},
|
120 |
+
{
|
121 |
+
"cell_type": "markdown",
|
122 |
+
"id": "f131dcfc",
|
123 |
+
"metadata": {},
|
124 |
+
"source": [
|
125 |
+
"### Testing\n",
|
126 |
+
"If we load the first file, we should get a file that is 10000 lines long and has one column\n",
|
127 |
+
"\n",
|
128 |
+
"As we do not preprocess the data in any way, but just write the read text into the file, this is all testing necessary"
|
129 |
+
]
|
130 |
+
},
|
131 |
+
{
|
132 |
+
"cell_type": "code",
|
133 |
+
"execution_count": 13,
|
134 |
+
"id": "df50af74",
|
135 |
+
"metadata": {},
|
136 |
+
"outputs": [],
|
137 |
+
"source": [
|
138 |
+
"with open(\"data/original/text_0.txt\", 'r', encoding='utf-8') as f:\n",
|
139 |
+
" lines = f.read().split('\\n')\n",
|
140 |
+
"lines = pd.DataFrame(lines)"
|
141 |
+
]
|
142 |
+
},
|
143 |
+
{
|
144 |
+
"cell_type": "code",
|
145 |
+
"execution_count": 14,
|
146 |
+
"id": "8ddb0085",
|
147 |
+
"metadata": {},
|
148 |
+
"outputs": [
|
149 |
+
{
|
150 |
+
"name": "stdout",
|
151 |
+
"output_type": "stream",
|
152 |
+
"text": [
|
153 |
+
"Passed\n"
|
154 |
+
]
|
155 |
+
}
|
156 |
+
],
|
157 |
+
"source": [
|
158 |
+
"assert lines.shape==(10000,1)\n",
|
159 |
+
"print(\"Passed\")"
|
160 |
+
]
|
161 |
+
},
|
162 |
+
{
|
163 |
+
"cell_type": "markdown",
|
164 |
+
"id": "1a65b268",
|
165 |
+
"metadata": {},
|
166 |
+
"source": [
|
167 |
+
"## SQuAD Data\n",
|
168 |
+
"In the following, we download the SQuAD dataset from huggingface (https://huggingface.co/datasets/squad). It was initially provided by Rajpurkar et al. from Stanford University.\n",
|
169 |
+
"\n",
|
170 |
+
"We again load the dataset and store it in chunks of 1000 into files."
|
171 |
+
]
|
172 |
+
},
|
173 |
+
{
|
174 |
+
"cell_type": "code",
|
175 |
+
"execution_count": 6,
|
176 |
+
"id": "6750ce6e",
|
177 |
+
"metadata": {},
|
178 |
+
"outputs": [
|
179 |
+
{
|
180 |
+
"ename": "AssertionError",
|
181 |
+
"evalue": "",
|
182 |
+
"output_type": "error",
|
183 |
+
"traceback": [
|
184 |
+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
185 |
+
"\u001b[0;31mAssertionError\u001b[0m Traceback (most recent call last)",
|
186 |
+
"Cell \u001b[0;32mIn [6], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m dataset \u001b[38;5;241m=\u001b[39m load_dataset(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124msquad\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n",
|
187 |
+
"File \u001b[0;32m~/anaconda3/envs/myenv/lib/python3.10/site-packages/datasets/load.py:1670\u001b[0m, in \u001b[0;36mload_dataset\u001b[0;34m(path, name, data_dir, data_files, split, cache_dir, features, download_config, download_mode, ignore_verifications, keep_in_memory, save_infos, revision, use_auth_token, task, streaming, **config_kwargs)\u001b[0m\n\u001b[1;32m 1667\u001b[0m ignore_verifications \u001b[38;5;241m=\u001b[39m ignore_verifications \u001b[38;5;129;01mor\u001b[39;00m save_infos\n\u001b[1;32m 1669\u001b[0m \u001b[38;5;66;03m# Create a dataset builder\u001b[39;00m\n\u001b[0;32m-> 1670\u001b[0m builder_instance \u001b[38;5;241m=\u001b[39m \u001b[43mload_dataset_builder\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1671\u001b[0m \u001b[43m \u001b[49m\u001b[43mpath\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpath\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1672\u001b[0m \u001b[43m \u001b[49m\u001b[43mname\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mname\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1673\u001b[0m \u001b[43m \u001b[49m\u001b[43mdata_dir\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdata_dir\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1674\u001b[0m \u001b[43m \u001b[49m\u001b[43mdata_files\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdata_files\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1675\u001b[0m \u001b[43m \u001b[49m\u001b[43mcache_dir\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mcache_dir\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1676\u001b[0m \u001b[43m \u001b[49m\u001b[43mfeatures\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mfeatures\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1677\u001b[0m \u001b[43m \u001b[49m\u001b[43mdownload_config\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdownload_config\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1678\u001b[0m \u001b[43m \u001b[49m\u001b[43mdownload_mode\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdownload_mode\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1679\u001b[0m \u001b[43m \u001b[49m\u001b[43mrevision\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mrevision\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1680\u001b[0m \u001b[43m \u001b[49m\u001b[43muse_auth_token\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43muse_auth_token\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1681\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mconfig_kwargs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1682\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1684\u001b[0m \u001b[38;5;66;03m# Return iterable dataset in case of streaming\u001b[39;00m\n\u001b[1;32m 1685\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m streaming:\n",
|
188 |
+
"File \u001b[0;32m~/anaconda3/envs/myenv/lib/python3.10/site-packages/datasets/load.py:1447\u001b[0m, in \u001b[0;36mload_dataset_builder\u001b[0;34m(path, name, data_dir, data_files, cache_dir, features, download_config, download_mode, revision, use_auth_token, **config_kwargs)\u001b[0m\n\u001b[1;32m 1445\u001b[0m download_config \u001b[38;5;241m=\u001b[39m download_config\u001b[38;5;241m.\u001b[39mcopy() \u001b[38;5;28;01mif\u001b[39;00m download_config \u001b[38;5;28;01melse\u001b[39;00m DownloadConfig()\n\u001b[1;32m 1446\u001b[0m download_config\u001b[38;5;241m.\u001b[39muse_auth_token \u001b[38;5;241m=\u001b[39m use_auth_token\n\u001b[0;32m-> 1447\u001b[0m dataset_module \u001b[38;5;241m=\u001b[39m \u001b[43mdataset_module_factory\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1448\u001b[0m \u001b[43m \u001b[49m\u001b[43mpath\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1449\u001b[0m \u001b[43m \u001b[49m\u001b[43mrevision\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mrevision\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1450\u001b[0m \u001b[43m \u001b[49m\u001b[43mdownload_config\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdownload_config\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1451\u001b[0m \u001b[43m \u001b[49m\u001b[43mdownload_mode\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdownload_mode\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1452\u001b[0m \u001b[43m \u001b[49m\u001b[43mdata_dir\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdata_dir\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1453\u001b[0m \u001b[43m \u001b[49m\u001b[43mdata_files\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdata_files\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1454\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1456\u001b[0m \u001b[38;5;66;03m# Get dataset builder class from the processing script\u001b[39;00m\n\u001b[1;32m 1457\u001b[0m builder_cls \u001b[38;5;241m=\u001b[39m import_main_class(dataset_module\u001b[38;5;241m.\u001b[39mmodule_path)\n",
|
189 |
+
"File \u001b[0;32m~/anaconda3/envs/myenv/lib/python3.10/site-packages/datasets/load.py:1172\u001b[0m, in \u001b[0;36mdataset_module_factory\u001b[0;34m(path, revision, download_config, download_mode, dynamic_modules_path, data_dir, data_files, **download_kwargs)\u001b[0m\n\u001b[1;32m 1167\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(e1, \u001b[38;5;167;01mFileNotFoundError\u001b[39;00m):\n\u001b[1;32m 1168\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mFileNotFoundError\u001b[39;00m(\n\u001b[1;32m 1169\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mCouldn\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mt find a dataset script at \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mrelative_to_absolute_path(combined_path)\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m or any data file in the same directory. \u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 1170\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mCouldn\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mt find \u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;132;01m{\u001b[39;00mpath\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m on the Hugging Face Hub either: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mtype\u001b[39m(e1)\u001b[38;5;241m.\u001b[39m\u001b[38;5;18m__name__\u001b[39m\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00me1\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 1171\u001b[0m ) \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;28mNone\u001b[39m\n\u001b[0;32m-> 1172\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m e1 \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;28mNone\u001b[39m\n\u001b[1;32m 1173\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 1174\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mFileNotFoundError\u001b[39;00m(\n\u001b[1;32m 1175\u001b[0m \u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mCouldn\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mt find a dataset script at \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mrelative_to_absolute_path(combined_path)\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m or any data file in the same directory.\u001b[39m\u001b[38;5;124m\"\u001b[39m\n\u001b[1;32m 1176\u001b[0m )\n",
|
190 |
+
"File \u001b[0;32m~/anaconda3/envs/myenv/lib/python3.10/site-packages/datasets/load.py:1151\u001b[0m, in \u001b[0;36mdataset_module_factory\u001b[0;34m(path, revision, download_config, download_mode, dynamic_modules_path, data_dir, data_files, **download_kwargs)\u001b[0m\n\u001b[1;32m 1143\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m HubDatasetModuleFactoryWithScript(\n\u001b[1;32m 1144\u001b[0m path,\n\u001b[1;32m 1145\u001b[0m revision\u001b[38;5;241m=\u001b[39mrevision,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 1148\u001b[0m dynamic_modules_path\u001b[38;5;241m=\u001b[39mdynamic_modules_path,\n\u001b[1;32m 1149\u001b[0m )\u001b[38;5;241m.\u001b[39mget_module()\n\u001b[1;32m 1150\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1151\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mHubDatasetModuleFactoryWithoutScript\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1152\u001b[0m \u001b[43m \u001b[49m\u001b[43mpath\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1153\u001b[0m \u001b[43m \u001b[49m\u001b[43mrevision\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mrevision\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1154\u001b[0m \u001b[43m \u001b[49m\u001b[43mdata_dir\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdata_dir\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1155\u001b[0m \u001b[43m \u001b[49m\u001b[43mdata_files\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdata_files\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1156\u001b[0m \u001b[43m \u001b[49m\u001b[43mdownload_config\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdownload_config\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1157\u001b[0m \u001b[43m \u001b[49m\u001b[43mdownload_mode\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mdownload_mode\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1158\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241m.\u001b[39mget_module()\n\u001b[1;32m 1159\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mException\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m e1: \u001b[38;5;66;03m# noqa: all the attempts failed, before raising the error we should check if the module is already cached.\u001b[39;00m\n\u001b[1;32m 1160\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m:\n",
|
191 |
+
"File \u001b[0;32m~/anaconda3/envs/myenv/lib/python3.10/site-packages/datasets/load.py:744\u001b[0m, in \u001b[0;36mHubDatasetModuleFactoryWithoutScript.__init__\u001b[0;34m(self, name, revision, data_dir, data_files, download_config, download_mode)\u001b[0m\n\u001b[1;32m 742\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdownload_config \u001b[38;5;241m=\u001b[39m download_config \u001b[38;5;129;01mor\u001b[39;00m DownloadConfig()\n\u001b[1;32m 743\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdownload_mode \u001b[38;5;241m=\u001b[39m download_mode\n\u001b[0;32m--> 744\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mname\u001b[38;5;241m.\u001b[39mcount(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m/\u001b[39m\u001b[38;5;124m\"\u001b[39m) \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[1;32m 745\u001b[0m increase_load_count(name, resource_type\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mdataset\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n",
|
192 |
+
"\u001b[0;31mAssertionError\u001b[0m: "
|
193 |
+
]
|
194 |
+
}
|
195 |
+
],
|
196 |
+
"source": [
|
197 |
+
"dataset = load_dataset(\"squad\")"
|
198 |
+
]
|
199 |
+
},
|
200 |
+
{
|
201 |
+
"cell_type": "code",
|
202 |
+
"execution_count": null,
|
203 |
+
"id": "65a7ee23",
|
204 |
+
"metadata": {},
|
205 |
+
"outputs": [
|
206 |
+
{
|
207 |
+
"ename": "",
|
208 |
+
"evalue": "",
|
209 |
+
"output_type": "error",
|
210 |
+
"traceback": [
|
211 |
+
"\u001b[1;31mnotebook controller is DISPOSED. \n",
|
212 |
+
"\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
|
213 |
+
]
|
214 |
+
},
|
215 |
+
{
|
216 |
+
"ename": "",
|
217 |
+
"evalue": "",
|
218 |
+
"output_type": "error",
|
219 |
+
"traceback": [
|
220 |
+
"\u001b[1;31mnotebook controller is DISPOSED. \n",
|
221 |
+
"\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
|
222 |
+
]
|
223 |
+
}
|
224 |
+
],
|
225 |
+
"source": [
|
226 |
+
"os.mkdir(\"data/training_squad\")\n",
|
227 |
+
"os.mkdir(\"data/test_squad\")"
|
228 |
+
]
|
229 |
+
},
|
230 |
+
{
|
231 |
+
"cell_type": "code",
|
232 |
+
"execution_count": null,
|
233 |
+
"id": "f6ebf63e",
|
234 |
+
"metadata": {},
|
235 |
+
"outputs": [
|
236 |
+
{
|
237 |
+
"ename": "",
|
238 |
+
"evalue": "",
|
239 |
+
"output_type": "error",
|
240 |
+
"traceback": [
|
241 |
+
"\u001b[1;31mnotebook controller is DISPOSED. \n",
|
242 |
+
"\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
|
243 |
+
]
|
244 |
+
},
|
245 |
+
{
|
246 |
+
"ename": "",
|
247 |
+
"evalue": "",
|
248 |
+
"output_type": "error",
|
249 |
+
"traceback": [
|
250 |
+
"\u001b[1;31mnotebook controller is DISPOSED. \n",
|
251 |
+
"\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
|
252 |
+
]
|
253 |
+
}
|
254 |
+
],
|
255 |
+
"source": [
|
256 |
+
"# we already have a training and test split. Each sample has an id, title, context, question and answers.\n",
|
257 |
+
"dataset"
|
258 |
+
]
|
259 |
+
},
|
260 |
+
{
|
261 |
+
"cell_type": "code",
|
262 |
+
"execution_count": null,
|
263 |
+
"id": "f67ae448",
|
264 |
+
"metadata": {},
|
265 |
+
"outputs": [
|
266 |
+
{
|
267 |
+
"ename": "",
|
268 |
+
"evalue": "",
|
269 |
+
"output_type": "error",
|
270 |
+
"traceback": [
|
271 |
+
"\u001b[1;31mnotebook controller is DISPOSED. \n",
|
272 |
+
"\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
|
273 |
+
]
|
274 |
+
},
|
275 |
+
{
|
276 |
+
"ename": "",
|
277 |
+
"evalue": "",
|
278 |
+
"output_type": "error",
|
279 |
+
"traceback": [
|
280 |
+
"\u001b[1;31mnotebook controller is DISPOSED. \n",
|
281 |
+
"\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
|
282 |
+
]
|
283 |
+
}
|
284 |
+
],
|
285 |
+
"source": [
|
286 |
+
"# answers are provided like that - we need to extract answer_end for the model\n",
|
287 |
+
"dataset['train']['answers'][0]"
|
288 |
+
]
|
289 |
+
},
|
290 |
+
{
|
291 |
+
"cell_type": "code",
|
292 |
+
"execution_count": null,
|
293 |
+
"id": "101cd650",
|
294 |
+
"metadata": {},
|
295 |
+
"outputs": [
|
296 |
+
{
|
297 |
+
"ename": "",
|
298 |
+
"evalue": "",
|
299 |
+
"output_type": "error",
|
300 |
+
"traceback": [
|
301 |
+
"\u001b[1;31mnotebook controller is DISPOSED. \n",
|
302 |
+
"\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
|
303 |
+
]
|
304 |
+
},
|
305 |
+
{
|
306 |
+
"ename": "",
|
307 |
+
"evalue": "",
|
308 |
+
"output_type": "error",
|
309 |
+
"traceback": [
|
310 |
+
"\u001b[1;31mnotebook controller is DISPOSED. \n",
|
311 |
+
"\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
|
312 |
+
]
|
313 |
+
}
|
314 |
+
],
|
315 |
+
"source": [
|
316 |
+
"# column contains the split (either train or validation), save_dir is the directory\n",
|
317 |
+
"def save_samples(column, save_dir):\n",
|
318 |
+
" text = []\n",
|
319 |
+
" i = 0\n",
|
320 |
+
"\n",
|
321 |
+
" for sample in tqdm(dataset[column]):\n",
|
322 |
+
" \n",
|
323 |
+
" # preprocess the context and question by removing the newlines\n",
|
324 |
+
" context = sample['context'].replace('\\n','')\n",
|
325 |
+
" question = sample['question'].replace('\\n','')\n",
|
326 |
+
"\n",
|
327 |
+
" # get the answer as text and start character index\n",
|
328 |
+
" answer_text = sample['answers']['text'][0]\n",
|
329 |
+
" answer_start = str(sample['answers']['answer_start'][0])\n",
|
330 |
+
" \n",
|
331 |
+
" text.append([context, question, answer_text, answer_start])\n",
|
332 |
+
"\n",
|
333 |
+
" # we choose chunks of 1000\n",
|
334 |
+
" if len(text) == 1000:\n",
|
335 |
+
" with open(f\"data/{save_dir}/text_{i}.txt\", 'w', encoding='utf-8') as f:\n",
|
336 |
+
" f.write(\"\\n\".join([\"\\t\".join(t) for t in text]))\n",
|
337 |
+
" text = []\n",
|
338 |
+
" i += 1\n",
|
339 |
+
"\n",
|
340 |
+
" # save remaining\n",
|
341 |
+
" with open(f\"data/{save_dir}/text_{i}.txt\", 'w', encoding='utf-8') as f:\n",
|
342 |
+
" f.write(\"\\n\".join([\"\\t\".join(t) for t in text]))\n",
|
343 |
+
"\n",
|
344 |
+
"save_samples(\"train\", \"training_squad\")\n",
|
345 |
+
"save_samples(\"validation\", \"test_squad\")\n",
|
346 |
+
" "
|
347 |
+
]
|
348 |
+
},
|
349 |
+
{
|
350 |
+
"cell_type": "markdown",
|
351 |
+
"id": "67044d13",
|
352 |
+
"metadata": {
|
353 |
+
"collapsed": false,
|
354 |
+
"jupyter": {
|
355 |
+
"outputs_hidden": false
|
356 |
+
}
|
357 |
+
},
|
358 |
+
"source": [
|
359 |
+
"### Testing\n",
|
360 |
+
"If we load a file, we should get a file with 10000 lines and 4 columns\n",
|
361 |
+
"\n",
|
362 |
+
"Also, we want to assure the correct interval. Hence, the second test."
|
363 |
+
]
|
364 |
+
},
|
365 |
+
{
|
366 |
+
"cell_type": "code",
|
367 |
+
"execution_count": null,
|
368 |
+
"id": "446281cf",
|
369 |
+
"metadata": {},
|
370 |
+
"outputs": [
|
371 |
+
{
|
372 |
+
"ename": "",
|
373 |
+
"evalue": "",
|
374 |
+
"output_type": "error",
|
375 |
+
"traceback": [
|
376 |
+
"\u001b[1;31mnotebook controller is DISPOSED. \n",
|
377 |
+
"\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
|
378 |
+
]
|
379 |
+
},
|
380 |
+
{
|
381 |
+
"ename": "",
|
382 |
+
"evalue": "",
|
383 |
+
"output_type": "error",
|
384 |
+
"traceback": [
|
385 |
+
"\u001b[1;31mnotebook controller is DISPOSED. \n",
|
386 |
+
"\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
|
387 |
+
]
|
388 |
+
}
|
389 |
+
],
|
390 |
+
"source": [
|
391 |
+
"with open(\"data/training_squad/text_0.txt\", 'r', encoding='utf-8') as f:\n",
|
392 |
+
" lines = f.read().split('\\n')\n",
|
393 |
+
" \n",
|
394 |
+
"lines = pd.DataFrame([line.split(\"\\t\") for line in lines], columns=[\"context\", \"question\", \"answer\", \"answer_start\"])"
|
395 |
+
]
|
396 |
+
},
|
397 |
+
{
|
398 |
+
"cell_type": "code",
|
399 |
+
"execution_count": null,
|
400 |
+
"id": "ccd5c650",
|
401 |
+
"metadata": {},
|
402 |
+
"outputs": [
|
403 |
+
{
|
404 |
+
"ename": "",
|
405 |
+
"evalue": "",
|
406 |
+
"output_type": "error",
|
407 |
+
"traceback": [
|
408 |
+
"\u001b[1;31mnotebook controller is DISPOSED. \n",
|
409 |
+
"\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
|
410 |
+
]
|
411 |
+
},
|
412 |
+
{
|
413 |
+
"ename": "",
|
414 |
+
"evalue": "",
|
415 |
+
"output_type": "error",
|
416 |
+
"traceback": [
|
417 |
+
"\u001b[1;31mnotebook controller is DISPOSED. \n",
|
418 |
+
"\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
|
419 |
+
]
|
420 |
+
}
|
421 |
+
],
|
422 |
+
"source": [
|
423 |
+
"assert lines.shape==(1000,4)\n",
|
424 |
+
"print(\"Passed\")"
|
425 |
+
]
|
426 |
+
},
|
427 |
+
{
|
428 |
+
"cell_type": "code",
|
429 |
+
"execution_count": null,
|
430 |
+
"id": "2c9e4b70",
|
431 |
+
"metadata": {},
|
432 |
+
"outputs": [
|
433 |
+
{
|
434 |
+
"ename": "",
|
435 |
+
"evalue": "",
|
436 |
+
"output_type": "error",
|
437 |
+
"traceback": [
|
438 |
+
"\u001b[1;31mnotebook controller is DISPOSED. \n",
|
439 |
+
"\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
|
440 |
+
]
|
441 |
+
},
|
442 |
+
{
|
443 |
+
"ename": "",
|
444 |
+
"evalue": "",
|
445 |
+
"output_type": "error",
|
446 |
+
"traceback": [
|
447 |
+
"\u001b[1;31mnotebook controller is DISPOSED. \n",
|
448 |
+
"\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
|
449 |
+
]
|
450 |
+
}
|
451 |
+
],
|
452 |
+
"source": [
|
453 |
+
"# we assert that we have the right interval\n",
|
454 |
+
"for ind, line in lines.iterrows():\n",
|
455 |
+
" sample = line\n",
|
456 |
+
" answer_start = int(sample['answer_start'])\n",
|
457 |
+
" assert sample['context'][answer_start:answer_start+len(sample['answer'])] == sample['answer']\n",
|
458 |
+
"print(\"Passed\")"
|
459 |
+
]
|
460 |
+
},
|
461 |
+
{
|
462 |
+
"cell_type": "markdown",
|
463 |
+
"id": "02265ace",
|
464 |
+
"metadata": {},
|
465 |
+
"source": [
|
466 |
+
"## Natural Questions Dataset\n",
|
467 |
+
"* Download from https://ai.google.com/research/NaturalQuestions via gsutil (the one from huggingface has 134.92GB, the one from google cloud is in archives)\n",
|
468 |
+
"* Use gunzip to get some samples - we then get `.jsonl`files\n",
|
469 |
+
"* The dataset is a lot more messy, as it is just wikipedia articles with all web artifacts\n",
|
470 |
+
" * I cleaned the html tags\n",
|
471 |
+
" * Also I chose a random interval (containing the answer) from the dataset\n",
|
472 |
+
" * We can't send the whole text into the model anyways"
|
473 |
+
]
|
474 |
+
},
|
475 |
+
{
|
476 |
+
"cell_type": "code",
|
477 |
+
"execution_count": null,
|
478 |
+
"id": "f3bce0c1",
|
479 |
+
"metadata": {},
|
480 |
+
"outputs": [
|
481 |
+
{
|
482 |
+
"ename": "",
|
483 |
+
"evalue": "",
|
484 |
+
"output_type": "error",
|
485 |
+
"traceback": [
|
486 |
+
"\u001b[1;31mnotebook controller is DISPOSED. \n",
|
487 |
+
"\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
|
488 |
+
]
|
489 |
+
},
|
490 |
+
{
|
491 |
+
"ename": "",
|
492 |
+
"evalue": "",
|
493 |
+
"output_type": "error",
|
494 |
+
"traceback": [
|
495 |
+
"\u001b[1;31mnotebook controller is DISPOSED. \n",
|
496 |
+
"\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
|
497 |
+
]
|
498 |
+
}
|
499 |
+
],
|
500 |
+
"source": [
|
501 |
+
"from pathlib import Path\n",
|
502 |
+
"paths = [str(x) for x in Path('data/natural_questions/v1.0/train/').glob('**/*.jsonl')]"
|
503 |
+
]
|
504 |
+
},
|
505 |
+
{
|
506 |
+
"cell_type": "code",
|
507 |
+
"execution_count": null,
|
508 |
+
"id": "e9c58c00",
|
509 |
+
"metadata": {},
|
510 |
+
"outputs": [
|
511 |
+
{
|
512 |
+
"ename": "",
|
513 |
+
"evalue": "",
|
514 |
+
"output_type": "error",
|
515 |
+
"traceback": [
|
516 |
+
"\u001b[1;31mnotebook controller is DISPOSED. \n",
|
517 |
+
"\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
|
518 |
+
]
|
519 |
+
},
|
520 |
+
{
|
521 |
+
"ename": "",
|
522 |
+
"evalue": "",
|
523 |
+
"output_type": "error",
|
524 |
+
"traceback": [
|
525 |
+
"\u001b[1;31mnotebook controller is DISPOSED. \n",
|
526 |
+
"\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
|
527 |
+
]
|
528 |
+
}
|
529 |
+
],
|
530 |
+
"source": [
|
531 |
+
"os.mkdir(\"data/natural_questions_train\")"
|
532 |
+
]
|
533 |
+
},
|
534 |
+
{
|
535 |
+
"cell_type": "code",
|
536 |
+
"execution_count": null,
|
537 |
+
"id": "0ed7ba6c",
|
538 |
+
"metadata": {},
|
539 |
+
"outputs": [
|
540 |
+
{
|
541 |
+
"ename": "",
|
542 |
+
"evalue": "",
|
543 |
+
"output_type": "error",
|
544 |
+
"traceback": [
|
545 |
+
"\u001b[1;31mnotebook controller is DISPOSED. \n",
|
546 |
+
"\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
|
547 |
+
]
|
548 |
+
},
|
549 |
+
{
|
550 |
+
"ename": "",
|
551 |
+
"evalue": "",
|
552 |
+
"output_type": "error",
|
553 |
+
"traceback": [
|
554 |
+
"\u001b[1;31mnotebook controller is DISPOSED. \n",
|
555 |
+
"\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
|
556 |
+
]
|
557 |
+
}
|
558 |
+
],
|
559 |
+
"source": [
|
560 |
+
"import re\n",
|
561 |
+
"\n",
|
562 |
+
"# clean html tags\n",
|
563 |
+
"CLEANR = re.compile('<.+?>')\n",
|
564 |
+
"# clean multiple spaces\n",
|
565 |
+
"CLEANMULTSPACE = re.compile('(\\s)+')\n",
|
566 |
+
"\n",
|
567 |
+
"# the function takes an html documents and removes artifacts\n",
|
568 |
+
"def cleanhtml(raw_html):\n",
|
569 |
+
" # tags\n",
|
570 |
+
" cleantext = re.sub(CLEANR, '', raw_html)\n",
|
571 |
+
" # newlines\n",
|
572 |
+
" cleantext = cleantext.replace(\"\\n\", '')\n",
|
573 |
+
" # tabs\n",
|
574 |
+
" cleantext = cleantext.replace(\"\\t\", '')\n",
|
575 |
+
" # character encodings\n",
|
576 |
+
" cleantext = cleantext.replace(\"'\", \"'\")\n",
|
577 |
+
" cleantext = cleantext.replace(\"&\", \"'\")\n",
|
578 |
+
" cleantext = cleantext.replace(\""\", '\"')\n",
|
579 |
+
" # multiple spaces\n",
|
580 |
+
" cleantext = re.sub(CLEANMULTSPACE, ' ', cleantext)\n",
|
581 |
+
" # documents end with this tags, if it is present in the string, cut it off\n",
|
582 |
+
" idx = cleantext.find(\"<!-- NewPP limit\")\n",
|
583 |
+
" if idx > -1:\n",
|
584 |
+
" cleantext = cleantext[:idx]\n",
|
585 |
+
" return cleantext.strip()"
|
586 |
+
]
|
587 |
+
},
|
588 |
+
{
|
589 |
+
"cell_type": "code",
|
590 |
+
"execution_count": null,
|
591 |
+
"id": "66ca19ac",
|
592 |
+
"metadata": {},
|
593 |
+
"outputs": [
|
594 |
+
{
|
595 |
+
"ename": "",
|
596 |
+
"evalue": "",
|
597 |
+
"output_type": "error",
|
598 |
+
"traceback": [
|
599 |
+
"\u001b[1;31mnotebook controller is DISPOSED. \n",
|
600 |
+
"\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
|
601 |
+
]
|
602 |
+
},
|
603 |
+
{
|
604 |
+
"ename": "",
|
605 |
+
"evalue": "",
|
606 |
+
"output_type": "error",
|
607 |
+
"traceback": [
|
608 |
+
"\u001b[1;31mnotebook controller is DISPOSED. \n",
|
609 |
+
"\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
|
610 |
+
]
|
611 |
+
}
|
612 |
+
],
|
613 |
+
"source": [
|
614 |
+
"import json\n",
|
615 |
+
"\n",
|
616 |
+
"# file count\n",
|
617 |
+
"i = 0\n",
|
618 |
+
"data = []\n",
|
619 |
+
"\n",
|
620 |
+
"# iterate over all json files\n",
|
621 |
+
"for path in paths:\n",
|
622 |
+
" print(path)\n",
|
623 |
+
" # read file and store as list (this requires much memory, as the files are huge)\n",
|
624 |
+
" with open(path, 'r') as json_file:\n",
|
625 |
+
" json_list = list(json_file)\n",
|
626 |
+
" \n",
|
627 |
+
" # process every context, question, answer pair\n",
|
628 |
+
" for json_str in json_list:\n",
|
629 |
+
" result = json.loads(json_str)\n",
|
630 |
+
"\n",
|
631 |
+
" # append a question mark - SQuAD questions end with a qm too\n",
|
632 |
+
" question = result['question_text'] + \"?\"\n",
|
633 |
+
" \n",
|
634 |
+
" # some question do not contain an answer - we do not need them\n",
|
635 |
+
" if(len(result['annotations'][0]['short_answers'])==0):\n",
|
636 |
+
" continue\n",
|
637 |
+
"\n",
|
638 |
+
" # get true start/end byte\n",
|
639 |
+
" true_start = result['annotations'][0]['short_answers'][0]['start_byte']\n",
|
640 |
+
" true_end = result['annotations'][0]['short_answers'][0]['end_byte']\n",
|
641 |
+
"\n",
|
642 |
+
" # convert to bytes\n",
|
643 |
+
" byte_encoding = bytes(result['document_html'], encoding='utf-8')\n",
|
644 |
+
" \n",
|
645 |
+
" # the document is the whole wikipedia article, we randomly choose an appropriate part (containing the\n",
|
646 |
+
" # answer): we have 512 tokens as the input for the model - 4000 bytes lead to a good length\n",
|
647 |
+
" max_back = 3500 if true_start >= 3500 else true_start\n",
|
648 |
+
" first = random.randint(int(true_start)-max_back, int(true_start))\n",
|
649 |
+
" end = first + 3500 + true_end - true_start\n",
|
650 |
+
" \n",
|
651 |
+
" # get chosen context\n",
|
652 |
+
" cleanbytes = byte_encoding[first:end]\n",
|
653 |
+
" # decode back to text - if our end byte is the middle of a word, we ignore it and cut it off\n",
|
654 |
+
" cleantext = bytes.decode(cleanbytes, errors='ignore')\n",
|
655 |
+
" # clean html tags\n",
|
656 |
+
" cleantext = cleanhtml(cleantext)\n",
|
657 |
+
"\n",
|
658 |
+
" # find the true answer\n",
|
659 |
+
" answer_start = cleanbytes.find(byte_encoding[true_start:true_end])\n",
|
660 |
+
" true_answer = bytes.decode(cleanbytes[answer_start:answer_start+(true_end-true_start)])\n",
|
661 |
+
" \n",
|
662 |
+
" # clean html tags\n",
|
663 |
+
" true_answer = cleanhtml(true_answer)\n",
|
664 |
+
" \n",
|
665 |
+
" start_ind = cleantext.find(true_answer)\n",
|
666 |
+
" \n",
|
667 |
+
" # If cleaning the string makes the answer not findable skip it\n",
|
668 |
+
" # this hardly ever happens, except if there is an emense amount of web artifacts\n",
|
669 |
+
" if start_ind == -1:\n",
|
670 |
+
" continue\n",
|
671 |
+
" \n",
|
672 |
+
" data.append([cleantext, question, true_answer, str(start_ind)])\n",
|
673 |
+
"\n",
|
674 |
+
" if len(data) == 1000:\n",
|
675 |
+
" with open(f\"data/natural_questions_train/text_{i}.txt\", 'w', encoding='utf-8') as f:\n",
|
676 |
+
" f.write(\"\\n\".join([\"\\t\".join(t) for t in data]))\n",
|
677 |
+
" i += 1\n",
|
678 |
+
" data = []\n",
|
679 |
+
"with open(f\"data/natural_questions_train/text_{i}.txt\", 'w', encoding='utf-8') as f:\n",
|
680 |
+
" f.write(\"\\n\".join([\"\\t\".join(t) for t in data]))"
|
681 |
+
]
|
682 |
+
},
|
683 |
+
{
|
684 |
+
"cell_type": "markdown",
|
685 |
+
"id": "30f26b4e",
|
686 |
+
"metadata": {},
|
687 |
+
"source": [
|
688 |
+
"### Testing\n",
|
689 |
+
"In the following, we first check if the shape of the file is correct.\n",
|
690 |
+
"\n",
|
691 |
+
"Then we iterate over the file and check if the answers according to the file are the same as in the original file."
|
692 |
+
]
|
693 |
+
},
|
694 |
+
{
|
695 |
+
"cell_type": "code",
|
696 |
+
"execution_count": null,
|
697 |
+
"id": "490ac0db",
|
698 |
+
"metadata": {},
|
699 |
+
"outputs": [
|
700 |
+
{
|
701 |
+
"ename": "",
|
702 |
+
"evalue": "",
|
703 |
+
"output_type": "error",
|
704 |
+
"traceback": [
|
705 |
+
"\u001b[1;31mnotebook controller is DISPOSED. \n",
|
706 |
+
"\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
|
707 |
+
]
|
708 |
+
},
|
709 |
+
{
|
710 |
+
"ename": "",
|
711 |
+
"evalue": "",
|
712 |
+
"output_type": "error",
|
713 |
+
"traceback": [
|
714 |
+
"\u001b[1;31mnotebook controller is DISPOSED. \n",
|
715 |
+
"\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
|
716 |
+
]
|
717 |
+
}
|
718 |
+
],
|
719 |
+
"source": [
|
720 |
+
"with open(\"data/natural_questions_train/text_0.txt\", 'r', encoding='utf-8') as f:\n",
|
721 |
+
" lines = f.read().split('\\n')\n",
|
722 |
+
" \n",
|
723 |
+
"lines = pd.DataFrame([line.split(\"\\t\") for line in lines], columns=[\"context\", \"question\", \"answer\", \"answer_start\"])"
|
724 |
+
]
|
725 |
+
},
|
726 |
+
{
|
727 |
+
"cell_type": "code",
|
728 |
+
"execution_count": null,
|
729 |
+
"id": "0d7cc3ee",
|
730 |
+
"metadata": {},
|
731 |
+
"outputs": [
|
732 |
+
{
|
733 |
+
"ename": "",
|
734 |
+
"evalue": "",
|
735 |
+
"output_type": "error",
|
736 |
+
"traceback": [
|
737 |
+
"\u001b[1;31mnotebook controller is DISPOSED. \n",
|
738 |
+
"\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
|
739 |
+
]
|
740 |
+
},
|
741 |
+
{
|
742 |
+
"ename": "",
|
743 |
+
"evalue": "",
|
744 |
+
"output_type": "error",
|
745 |
+
"traceback": [
|
746 |
+
"\u001b[1;31mnotebook controller is DISPOSED. \n",
|
747 |
+
"\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
|
748 |
+
]
|
749 |
+
}
|
750 |
+
],
|
751 |
+
"source": [
|
752 |
+
"assert lines.shape == (1000, 4)\n",
|
753 |
+
"print(\"Passed\")"
|
754 |
+
]
|
755 |
+
},
|
756 |
+
{
|
757 |
+
"cell_type": "code",
|
758 |
+
"execution_count": null,
|
759 |
+
"id": "0fd8a854",
|
760 |
+
"metadata": {},
|
761 |
+
"outputs": [
|
762 |
+
{
|
763 |
+
"ename": "",
|
764 |
+
"evalue": "",
|
765 |
+
"output_type": "error",
|
766 |
+
"traceback": [
|
767 |
+
"\u001b[1;31mnotebook controller is DISPOSED. \n",
|
768 |
+
"\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
|
769 |
+
]
|
770 |
+
},
|
771 |
+
{
|
772 |
+
"ename": "",
|
773 |
+
"evalue": "",
|
774 |
+
"output_type": "error",
|
775 |
+
"traceback": [
|
776 |
+
"\u001b[1;31mnotebook controller is DISPOSED. \n",
|
777 |
+
"\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
|
778 |
+
]
|
779 |
+
}
|
780 |
+
],
|
781 |
+
"source": [
|
782 |
+
"with open(\"data/natural_questions/v1.0/train/nq-train-00.jsonl\", 'r') as json_file:\n",
|
783 |
+
" json_list = list(json_file)[:500]\n",
|
784 |
+
"del json_file"
|
785 |
+
]
|
786 |
+
},
|
787 |
+
{
|
788 |
+
"cell_type": "code",
|
789 |
+
"execution_count": null,
|
790 |
+
"id": "170bff30",
|
791 |
+
"metadata": {},
|
792 |
+
"outputs": [
|
793 |
+
{
|
794 |
+
"ename": "",
|
795 |
+
"evalue": "",
|
796 |
+
"output_type": "error",
|
797 |
+
"traceback": [
|
798 |
+
"\u001b[1;31mnotebook controller is DISPOSED. \n",
|
799 |
+
"\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
|
800 |
+
]
|
801 |
+
},
|
802 |
+
{
|
803 |
+
"ename": "",
|
804 |
+
"evalue": "",
|
805 |
+
"output_type": "error",
|
806 |
+
"traceback": [
|
807 |
+
"\u001b[1;31mnotebook controller is DISPOSED. \n",
|
808 |
+
"\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
|
809 |
+
]
|
810 |
+
}
|
811 |
+
],
|
812 |
+
"source": [
|
813 |
+
"lines_index = 0\n",
|
814 |
+
"for i in range(len(json_list)):\n",
|
815 |
+
" result = json.loads(json_list[i])\n",
|
816 |
+
" \n",
|
817 |
+
" if(len(result['annotations'][0]['short_answers'])==0):\n",
|
818 |
+
" pass\n",
|
819 |
+
" else: \n",
|
820 |
+
" # assert that the question text is the same\n",
|
821 |
+
" assert result['question_text'] + \"?\" == lines.loc[lines_index, 'question']\n",
|
822 |
+
" true_start = result['annotations'][0]['short_answers'][0]['start_byte']\n",
|
823 |
+
" true_end = result['annotations'][0]['short_answers'][0]['end_byte']\n",
|
824 |
+
" true_answer = bytes.decode(bytes(result['document_html'], encoding='utf-8')[true_start:true_end])\n",
|
825 |
+
" \n",
|
826 |
+
" processed_answer = lines.loc[lines_index, 'answer']\n",
|
827 |
+
" # assert that the answer is the same\n",
|
828 |
+
" assert cleanhtml(true_answer) == processed_answer\n",
|
829 |
+
" \n",
|
830 |
+
" start_ind = int(lines.loc[lines_index, 'answer_start'])\n",
|
831 |
+
" # assert that the answer (according to the index) is the same\n",
|
832 |
+
" assert cleanhtml(true_answer) == lines.loc[lines_index, 'context'][start_ind:start_ind+len(processed_answer)]\n",
|
833 |
+
" \n",
|
834 |
+
" lines_index += 1\n",
|
835 |
+
" \n",
|
836 |
+
" if lines_index == len(lines):\n",
|
837 |
+
" break\n",
|
838 |
+
"print(\"Passed\")"
|
839 |
+
]
|
840 |
+
},
|
841 |
+
{
|
842 |
+
"cell_type": "markdown",
|
843 |
+
"id": "78e6e737",
|
844 |
+
"metadata": {},
|
845 |
+
"source": [
|
846 |
+
"## Hotpot QA"
|
847 |
+
]
|
848 |
+
},
|
849 |
+
{
|
850 |
+
"cell_type": "code",
|
851 |
+
"execution_count": null,
|
852 |
+
"id": "27efcc8c",
|
853 |
+
"metadata": {},
|
854 |
+
"outputs": [
|
855 |
+
{
|
856 |
+
"ename": "",
|
857 |
+
"evalue": "",
|
858 |
+
"output_type": "error",
|
859 |
+
"traceback": [
|
860 |
+
"\u001b[1;31mnotebook controller is DISPOSED. \n",
|
861 |
+
"\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
|
862 |
+
]
|
863 |
+
},
|
864 |
+
{
|
865 |
+
"ename": "",
|
866 |
+
"evalue": "",
|
867 |
+
"output_type": "error",
|
868 |
+
"traceback": [
|
869 |
+
"\u001b[1;31mnotebook controller is DISPOSED. \n",
|
870 |
+
"\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
|
871 |
+
]
|
872 |
+
}
|
873 |
+
],
|
874 |
+
"source": [
|
875 |
+
"ds = load_dataset(\"hotpot_qa\", 'fullwiki')"
|
876 |
+
]
|
877 |
+
},
|
878 |
+
{
|
879 |
+
"cell_type": "code",
|
880 |
+
"execution_count": null,
|
881 |
+
"id": "1493f21f",
|
882 |
+
"metadata": {},
|
883 |
+
"outputs": [
|
884 |
+
{
|
885 |
+
"ename": "",
|
886 |
+
"evalue": "",
|
887 |
+
"output_type": "error",
|
888 |
+
"traceback": [
|
889 |
+
"\u001b[1;31mnotebook controller is DISPOSED. \n",
|
890 |
+
"\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
|
891 |
+
]
|
892 |
+
},
|
893 |
+
{
|
894 |
+
"ename": "",
|
895 |
+
"evalue": "",
|
896 |
+
"output_type": "error",
|
897 |
+
"traceback": [
|
898 |
+
"\u001b[1;31mnotebook controller is DISPOSED. \n",
|
899 |
+
"\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
|
900 |
+
]
|
901 |
+
}
|
902 |
+
],
|
903 |
+
"source": [
|
904 |
+
"ds"
|
905 |
+
]
|
906 |
+
},
|
907 |
+
{
|
908 |
+
"cell_type": "code",
|
909 |
+
"execution_count": null,
|
910 |
+
"id": "2a047946",
|
911 |
+
"metadata": {},
|
912 |
+
"outputs": [
|
913 |
+
{
|
914 |
+
"ename": "",
|
915 |
+
"evalue": "",
|
916 |
+
"output_type": "error",
|
917 |
+
"traceback": [
|
918 |
+
"\u001b[1;31mnotebook controller is DISPOSED. \n",
|
919 |
+
"\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
|
920 |
+
]
|
921 |
+
},
|
922 |
+
{
|
923 |
+
"ename": "",
|
924 |
+
"evalue": "",
|
925 |
+
"output_type": "error",
|
926 |
+
"traceback": [
|
927 |
+
"\u001b[1;31mnotebook controller is DISPOSED. \n",
|
928 |
+
"\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
|
929 |
+
]
|
930 |
+
}
|
931 |
+
],
|
932 |
+
"source": [
|
933 |
+
"os.mkdir('data/hotpotqa_training')\n",
|
934 |
+
"os.mkdir('data/hotpotqa_test')"
|
935 |
+
]
|
936 |
+
},
|
937 |
+
{
|
938 |
+
"cell_type": "code",
|
939 |
+
"execution_count": null,
|
940 |
+
"id": "e65b6485",
|
941 |
+
"metadata": {},
|
942 |
+
"outputs": [
|
943 |
+
{
|
944 |
+
"ename": "",
|
945 |
+
"evalue": "",
|
946 |
+
"output_type": "error",
|
947 |
+
"traceback": [
|
948 |
+
"\u001b[1;31mnotebook controller is DISPOSED. \n",
|
949 |
+
"\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
|
950 |
+
]
|
951 |
+
},
|
952 |
+
{
|
953 |
+
"ename": "",
|
954 |
+
"evalue": "",
|
955 |
+
"output_type": "error",
|
956 |
+
"traceback": [
|
957 |
+
"\u001b[1;31mnotebook controller is DISPOSED. \n",
|
958 |
+
"\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
|
959 |
+
]
|
960 |
+
}
|
961 |
+
],
|
962 |
+
"source": [
|
963 |
+
"# column contains the split (either train or validation), save_dir is the directory\n",
|
964 |
+
"def save_samples(column, save_dir):\n",
|
965 |
+
" text = []\n",
|
966 |
+
" i = 0\n",
|
967 |
+
"\n",
|
968 |
+
" for sample in tqdm(ds[column]):\n",
|
969 |
+
" \n",
|
970 |
+
" # preprocess the context and question by removing the newlines\n",
|
971 |
+
" context = sample['context']['sentences']\n",
|
972 |
+
" context = \" \".join([\"\".join(sentence) for sentence in context])\n",
|
973 |
+
" question = sample['question'].replace('\\n','')\n",
|
974 |
+
" \n",
|
975 |
+
" # get the answer as text and start character index\n",
|
976 |
+
" answer_text = sample['answer']\n",
|
977 |
+
" answer_start = context.find(answer_text)\n",
|
978 |
+
" if answer_start == -1:\n",
|
979 |
+
" continue\n",
|
980 |
+
" \n",
|
981 |
+
" \n",
|
982 |
+
" \n",
|
983 |
+
" if answer_start > 1500:\n",
|
984 |
+
" first = random.randint(answer_start-1500, answer_start)\n",
|
985 |
+
" end = first + 1500 + len(answer_text)\n",
|
986 |
+
" \n",
|
987 |
+
" context = context[first:end+1]\n",
|
988 |
+
" answer_start = context.find(answer_text)\n",
|
989 |
+
" \n",
|
990 |
+
" if answer_start == -1:continue\n",
|
991 |
+
" \n",
|
992 |
+
" text.append([context, question, answer_text, str(answer_start)])\n",
|
993 |
+
"\n",
|
994 |
+
" # we choose chunks of 1000\n",
|
995 |
+
" if len(text) == 1000:\n",
|
996 |
+
" with open(f\"data/{save_dir}/text_{i}.txt\", 'w', encoding='utf-8') as f:\n",
|
997 |
+
" f.write(\"\\n\".join([\"\\t\".join(t) for t in text]))\n",
|
998 |
+
" text = []\n",
|
999 |
+
" i += 1\n",
|
1000 |
+
"\n",
|
1001 |
+
" # save remaining\n",
|
1002 |
+
" with open(f\"data/{save_dir}/text_{i}.txt\", 'w', encoding='utf-8') as f:\n",
|
1003 |
+
" f.write(\"\\n\".join([\"\\t\".join(t) for t in text]))\n",
|
1004 |
+
"\n",
|
1005 |
+
"save_samples(\"train\", \"hotpotqa_training\")\n",
|
1006 |
+
"save_samples(\"validation\", \"hotpotqa_test\")"
|
1007 |
+
]
|
1008 |
+
},
|
1009 |
+
{
|
1010 |
+
"cell_type": "markdown",
|
1011 |
+
"id": "97cc358f",
|
1012 |
+
"metadata": {},
|
1013 |
+
"source": [
|
1014 |
+
"## Testing"
|
1015 |
+
]
|
1016 |
+
},
|
1017 |
+
{
|
1018 |
+
"cell_type": "code",
|
1019 |
+
"execution_count": null,
|
1020 |
+
"id": "f321483c",
|
1021 |
+
"metadata": {},
|
1022 |
+
"outputs": [
|
1023 |
+
{
|
1024 |
+
"ename": "",
|
1025 |
+
"evalue": "",
|
1026 |
+
"output_type": "error",
|
1027 |
+
"traceback": [
|
1028 |
+
"\u001b[1;31mnotebook controller is DISPOSED. \n",
|
1029 |
+
"\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
|
1030 |
+
]
|
1031 |
+
},
|
1032 |
+
{
|
1033 |
+
"ename": "",
|
1034 |
+
"evalue": "",
|
1035 |
+
"output_type": "error",
|
1036 |
+
"traceback": [
|
1037 |
+
"\u001b[1;31mnotebook controller is DISPOSED. \n",
|
1038 |
+
"\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
|
1039 |
+
]
|
1040 |
+
}
|
1041 |
+
],
|
1042 |
+
"source": [
|
1043 |
+
"with open(\"data/hotpotqa_training/text_0.txt\", 'r', encoding='utf-8') as f:\n",
|
1044 |
+
" lines = f.read().split('\\n')\n",
|
1045 |
+
" \n",
|
1046 |
+
"lines = pd.DataFrame([line.split(\"\\t\") for line in lines], columns=[\"context\", \"question\", \"answer\", \"answer_start\"])"
|
1047 |
+
]
|
1048 |
+
},
|
1049 |
+
{
|
1050 |
+
"cell_type": "code",
|
1051 |
+
"execution_count": null,
|
1052 |
+
"id": "72a96e78",
|
1053 |
+
"metadata": {},
|
1054 |
+
"outputs": [
|
1055 |
+
{
|
1056 |
+
"ename": "",
|
1057 |
+
"evalue": "",
|
1058 |
+
"output_type": "error",
|
1059 |
+
"traceback": [
|
1060 |
+
"\u001b[1;31mnotebook controller is DISPOSED. \n",
|
1061 |
+
"\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
|
1062 |
+
]
|
1063 |
+
},
|
1064 |
+
{
|
1065 |
+
"ename": "",
|
1066 |
+
"evalue": "",
|
1067 |
+
"output_type": "error",
|
1068 |
+
"traceback": [
|
1069 |
+
"\u001b[1;31mnotebook controller is DISPOSED. \n",
|
1070 |
+
"\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
|
1071 |
+
]
|
1072 |
+
}
|
1073 |
+
],
|
1074 |
+
"source": [
|
1075 |
+
"assert lines.shape == (1000, 4)\n",
|
1076 |
+
"print(\"Passed\")"
|
1077 |
+
]
|
1078 |
+
},
|
1079 |
+
{
|
1080 |
+
"cell_type": "code",
|
1081 |
+
"execution_count": null,
|
1082 |
+
"id": "c32c2f16",
|
1083 |
+
"metadata": {},
|
1084 |
+
"outputs": [
|
1085 |
+
{
|
1086 |
+
"ename": "",
|
1087 |
+
"evalue": "",
|
1088 |
+
"output_type": "error",
|
1089 |
+
"traceback": [
|
1090 |
+
"\u001b[1;31mnotebook controller is DISPOSED. \n",
|
1091 |
+
"\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
|
1092 |
+
]
|
1093 |
+
},
|
1094 |
+
{
|
1095 |
+
"ename": "",
|
1096 |
+
"evalue": "",
|
1097 |
+
"output_type": "error",
|
1098 |
+
"traceback": [
|
1099 |
+
"\u001b[1;31mnotebook controller is DISPOSED. \n",
|
1100 |
+
"\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
|
1101 |
+
]
|
1102 |
+
}
|
1103 |
+
],
|
1104 |
+
"source": [
|
1105 |
+
"# we assert that we have the right interval\n",
|
1106 |
+
"for ind, line in lines.iterrows():\n",
|
1107 |
+
" sample = line\n",
|
1108 |
+
" answer_start = int(sample['answer_start'])\n",
|
1109 |
+
" assert sample['context'][answer_start:answer_start+len(sample['answer'])] == sample['answer']\n",
|
1110 |
+
"print(\"Passed\")"
|
1111 |
+
]
|
1112 |
+
},
|
1113 |
+
{
|
1114 |
+
"cell_type": "code",
|
1115 |
+
"execution_count": null,
|
1116 |
+
"id": "bc36fe7d",
|
1117 |
+
"metadata": {},
|
1118 |
+
"outputs": [
|
1119 |
+
{
|
1120 |
+
"ename": "",
|
1121 |
+
"evalue": "",
|
1122 |
+
"output_type": "error",
|
1123 |
+
"traceback": [
|
1124 |
+
"\u001b[1;31mnotebook controller is DISPOSED. \n",
|
1125 |
+
"\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
|
1126 |
+
]
|
1127 |
+
},
|
1128 |
+
{
|
1129 |
+
"ename": "",
|
1130 |
+
"evalue": "",
|
1131 |
+
"output_type": "error",
|
1132 |
+
"traceback": [
|
1133 |
+
"\u001b[1;31mnotebook controller is DISPOSED. \n",
|
1134 |
+
"\u001b[1;31mView Jupyter <a href='command:jupyter.viewOutput'>log</a> for further details."
|
1135 |
+
]
|
1136 |
+
}
|
1137 |
+
],
|
1138 |
+
"source": []
|
1139 |
+
}
|
1140 |
+
],
|
1141 |
+
"metadata": {
|
1142 |
+
"kernelspec": {
|
1143 |
+
"display_name": "Python 3 (ipykernel)",
|
1144 |
+
"language": "python",
|
1145 |
+
"name": "python3"
|
1146 |
+
},
|
1147 |
+
"language_info": {
|
1148 |
+
"codemirror_mode": {
|
1149 |
+
"name": "ipython",
|
1150 |
+
"version": 3
|
1151 |
+
},
|
1152 |
+
"file_extension": ".py",
|
1153 |
+
"mimetype": "text/x-python",
|
1154 |
+
"name": "python",
|
1155 |
+
"nbconvert_exporter": "python",
|
1156 |
+
"pygments_lexer": "ipython3",
|
1157 |
+
"version": "3.10.16"
|
1158 |
+
},
|
1159 |
+
"toc": {
|
1160 |
+
"base_numbering": 1,
|
1161 |
+
"nav_menu": {},
|
1162 |
+
"number_sections": true,
|
1163 |
+
"sideBar": true,
|
1164 |
+
"skip_h1_title": false,
|
1165 |
+
"title_cell": "Table of Contents",
|
1166 |
+
"title_sidebar": "Contents",
|
1167 |
+
"toc_cell": false,
|
1168 |
+
"toc_position": {},
|
1169 |
+
"toc_section_display": true,
|
1170 |
+
"toc_window_display": false
|
1171 |
+
},
|
1172 |
+
"varInspector": {
|
1173 |
+
"cols": {
|
1174 |
+
"lenName": 16,
|
1175 |
+
"lenType": 16,
|
1176 |
+
"lenVar": 40
|
1177 |
+
},
|
1178 |
+
"kernels_config": {
|
1179 |
+
"python": {
|
1180 |
+
"delete_cmd_postfix": "",
|
1181 |
+
"delete_cmd_prefix": "del ",
|
1182 |
+
"library": "var_list.py",
|
1183 |
+
"varRefreshCmd": "print(var_dic_list())"
|
1184 |
+
},
|
1185 |
+
"r": {
|
1186 |
+
"delete_cmd_postfix": ") ",
|
1187 |
+
"delete_cmd_prefix": "rm(",
|
1188 |
+
"library": "var_list.r",
|
1189 |
+
"varRefreshCmd": "cat(var_dic_list()) "
|
1190 |
+
}
|
1191 |
+
},
|
1192 |
+
"types_to_exclude": [
|
1193 |
+
"module",
|
1194 |
+
"function",
|
1195 |
+
"builtin_function_or_method",
|
1196 |
+
"instance",
|
1197 |
+
"_Feature"
|
1198 |
+
],
|
1199 |
+
"window_display": false
|
1200 |
+
},
|
1201 |
+
"vscode": {
|
1202 |
+
"interpreter": {
|
1203 |
+
"hash": "85bf9c14e9ba73b783ed1274d522bec79eb0b2b739090180d8ce17bb11aff4aa"
|
1204 |
+
}
|
1205 |
+
}
|
1206 |
+
},
|
1207 |
+
"nbformat": 4,
|
1208 |
+
"nbformat_minor": 5
|
1209 |
+
}
|
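The SQuAD cells above note that only answer_start is stored and that answer_end still has to be derived for the model. Below is a minimal sketch of that step, assuming the tab-separated chunk layout written by this notebook (context, question, answer, answer_start per line); the chunk path is only an example.

import pandas as pd

def load_chunk(path="data/training_squad/text_0.txt"):
    # each line holds: context \t question \t answer \t answer_start
    with open(path, "r", encoding="utf-8") as f:
        rows = [line.split("\t") for line in f.read().split("\n")]
    df = pd.DataFrame(rows, columns=["context", "question", "answer", "answer_start"])
    df["answer_start"] = df["answer_start"].astype(int)
    # the character end index is simply the start index plus the answer length
    df["answer_end"] = df["answer_start"] + df["answer"].str.len()
    return df

chunk = load_chunk()
# mirrors the notebook's own test: the [start, end) span must reproduce the answer text
assert all(row["context"][row["answer_start"]:row["answer_end"]] == row["answer"]
           for _, row in chunk.iterrows())

The same derivation applies to the Natural Questions and HotpotQA chunks, since they are written in the same four-column format.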
qa_model.py
ADDED
@@ -0,0 +1,532 @@
1 |
+
from torch import nn
|
2 |
+
import torch
|
3 |
+
from typing import Optional
|
4 |
+
import copy
|
5 |
+
import pandas as pd
|
6 |
+
|
7 |
+
"""
|
8 |
+
This module contains the implementation of the QA model. We define three different models and a dataset class.
|
9 |
+
The structure is based on the Hugging Face implementations.
|
10 |
+
https://huggingface.co/docs/transformers/model_doc/distilbert
|
11 |
+
"""
|
12 |
+
|
class SimpleQuestionDistilBERT(nn.Module):
    """
    This class implements a simple version of the DistilBERT question answering model, following the Hugging Face implementation:
    https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/distilbert/modeling_distilbert.py#L805

    It fine-tunes a given DistilBERT model end to end; we only add one linear layer on top, which predicts the start and end logits.
    """
    def __init__(self, distilbert, dropout=0.1):
        """
        Creates and initialises the model
        """
        super(SimpleQuestionDistilBERT, self).__init__()

        self.distilbert = distilbert

        self.dropout = nn.Dropout(dropout)

        self.classifier = nn.Linear(768, 2)

        # initialise weights of the classifier
        def init_weights(m):
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                m.bias.data.fill_(0.01)
        self.classifier.apply(init_weights)

    def forward(self,
                input_ids: Optional[torch.Tensor] = None,
                attention_mask: Optional[torch.Tensor] = None,
                head_mask: Optional[torch.Tensor] = None,
                inputs_embeds: Optional[torch.Tensor] = None,
                start_positions: Optional[torch.Tensor] = None,
                end_positions: Optional[torch.Tensor] = None,
                output_attentions: Optional[bool] = None,
                output_hidden_states: Optional[bool] = None,
                return_dict: Optional[bool] = None):
        """
        This function implements the forward pass of the model. It takes the input_ids and attention_mask and returns the start and end logits.
        """
        # make predictions on the base model
        distilbert_output = self.distilbert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # retrieve hidden states
        hidden_states = distilbert_output[0]  # (bs, max_query_len, dim)
        hidden_states = self.dropout(hidden_states)

        # make predictions on the head
        logits = self.classifier(hidden_states)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()  # (bs, max_query_len)
        end_logits = end_logits.squeeze(-1).contiguous()  # (bs, max_query_len)

        # calculate loss
        total_loss = None
        if start_positions is not None and end_positions is not None:
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)

            # sometimes the start/end positions are outside our model inputs, we ignore these terms
            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        return {"loss": total_loss,
                "start_logits": start_logits,
                "end_logits": end_logits,
                "hidden_states": distilbert_output.hidden_states,
                "attentions": distilbert_output.attentions}

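The forward pass above only returns start and end logits; turning them into an answer span is left to the caller. Below is a minimal, illustrative sketch (not part of qa_model.py) of one common way to pick the best valid span, assuming a batch size of 1 and a hypothetical maximum answer length of 30 tokens.

# Illustrative sketch only, not part of qa_model.py: pick the best valid (start, end) pair
# from the logits returned by SimpleQuestionDistilBERT.forward (batch size 1 assumed).
def best_span(start_logits, end_logits, max_answer_len=30):
    # score[i, j] = start_logits[i] + end_logits[j] for every candidate span (i, j)
    scores = start_logits[0].unsqueeze(1) + end_logits[0].unsqueeze(0)
    seq_len = scores.size(0)
    # keep only spans with i <= j < i + max_answer_len
    ones = torch.ones(seq_len, seq_len)
    valid = (torch.triu(ones) - torch.triu(ones, diagonal=max_answer_len)) == 1
    scores = scores.masked_fill(~valid, float("-inf"))
    flat = torch.argmax(scores)
    return (flat // seq_len).item(), (flat % seq_len).item()
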
class QuestionDistilBERT(nn.Module):
    """
    This class implements the DistilBERT question answering model. We freeze all layers of the base model and only fine-tune the head.
    The head consists of a transformer encoder with three layers and a classifier on top.
    """
    def __init__(self, distilbert, dropout=0.1):
        """
        Creates and initialises a QuestionDistilBERT instance
        """
        super(QuestionDistilBERT, self).__init__()

        # freeze the parameters of the base model
        for param in distilbert.parameters():
            param.requires_grad = False

        self.distilbert = distilbert
        self.relu = nn.ReLU()  # note: defined but not used in forward

        self.dropout = nn.Dropout(dropout)
        self.te = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model=768, nhead=12), num_layers=3)

        # create custom head
        self.classifier = nn.Sequential(
            nn.Dropout(dropout),
            nn.ReLU(),
            nn.Linear(768, 512),
            nn.Dropout(dropout),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.Dropout(dropout),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.Dropout(dropout),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.Dropout(dropout),
            nn.ReLU(),
            nn.Linear(64, 2)
        )

        # initialise weights of the linear layers
        def init_weights(m):
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                m.bias.data.fill_(0.01)

        self.classifier.apply(init_weights)

    def forward(self,
                input_ids: Optional[torch.Tensor] = None,
                attention_mask: Optional[torch.Tensor] = None,
                head_mask: Optional[torch.Tensor] = None,
                inputs_embeds: Optional[torch.Tensor] = None,
                start_positions: Optional[torch.Tensor] = None,
                end_positions: Optional[torch.Tensor] = None,
                output_attentions: Optional[bool] = None,
                output_hidden_states: Optional[bool] = None,
                return_dict: Optional[bool] = None):
        """
        This function implements the forward pass of the model. It takes the input_ids and attention_mask and returns the start and end logits.
        """
        # make predictions on the base model
        distilbert_output = self.distilbert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # retrieve hidden states and pass them through the trainable transformer encoder
        hidden_states = distilbert_output[0]  # (bs, max_query_len, dim)
        hidden_states = self.dropout(hidden_states)
        attn_output = self.te(hidden_states)

        # make predictions on the head
        logits = self.classifier(attn_output)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        # calculate loss
        total_loss = None
        if start_positions is not None and end_positions is not None:
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)

            # sometimes the start/end positions are outside our model inputs, we ignore these terms
            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        return {"loss": total_loss,
                "start_logits": start_logits,
                "end_logits": end_logits,
                "hidden_states": distilbert_output.hidden_states,
                "attentions": distilbert_output.attentions}

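Because QuestionDistilBERT freezes the whole base model, only the transformer-encoder head and the classifier receive gradients. A small, illustrative helper (not part of qa_model.py, and separate from the util.count_parameters used in the notebooks) to check that the split looks right:

# Illustrative sketch only, not part of qa_model.py: count trainable vs. frozen parameters.
def summarize_trainable(model):
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    frozen = sum(p.numel() for p in model.parameters() if not p.requires_grad)
    print(f"trainable: {trainable:,} | frozen: {frozen:,}")
    return trainable, frozen
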
class ReuseQuestionDistilBERT(nn.Module):
    """
    This class implements a model where all layers of the base DistilBERT model are frozen.
    Instead of training a completely new head, we copy the last two layers of the base model and add a classifier on top.
    """
    def __init__(self, distilbert, dropout=0.15):
        """
        Creates and initialises a ReuseQuestionDistilBERT instance
        """
        super(ReuseQuestionDistilBERT, self).__init__()
        # deep-copy the last two transformer blocks of the base model; these copies stay trainable
        self.te = copy.deepcopy(list(list(distilbert.children())[1].children())[0][-2:])
        # freeze the parameters of the base model
        for param in distilbert.parameters():
            param.requires_grad = False

        self.distilbert = distilbert
        self.relu = nn.ReLU()  # note: defined but not used in forward

        self.dropout = nn.Dropout(dropout)

        # create custom head
        self.classifier = nn.Linear(768, 2)

        def init_weights(m):
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                m.bias.data.fill_(0.01)
        self.classifier.apply(init_weights)

    def forward(self,
                input_ids: Optional[torch.Tensor] = None,
                attention_mask: Optional[torch.Tensor] = None,
                head_mask: Optional[torch.Tensor] = None,
                inputs_embeds: Optional[torch.Tensor] = None,
                start_positions: Optional[torch.Tensor] = None,
                end_positions: Optional[torch.Tensor] = None,
                output_attentions: Optional[bool] = None,
                output_hidden_states: Optional[bool] = None,
                return_dict: Optional[bool] = None):
        """
        This function implements the forward pass of the model. It takes the input_ids and attention_mask and returns the start and end logits.
        """
        # make predictions on the base model
        distilbert_output = self.distilbert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # retrieve hidden states and pass them through the copied (trainable) transformer blocks
        hidden_states = distilbert_output[0]  # (bs, max_query_len, dim)
        hidden_states = self.dropout(hidden_states)
        for te in self.te:
            hidden_states = te(
                x=hidden_states,
                attn_mask=attention_mask,
                head_mask=head_mask,
                output_attentions=output_attentions
            )[0]
        hidden_states = self.dropout(hidden_states)

        # make predictions on the head
        logits = self.classifier(hidden_states)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()  # (bs, max_query_len)
        end_logits = end_logits.squeeze(-1).contiguous()  # (bs, max_query_len)

        # calculate loss
        total_loss = None
        if start_positions is not None and end_positions is not None:
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)

            # sometimes the start/end positions are outside our model inputs, we ignore these terms
            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            loss_fct = nn.CrossEntropyLoss(ignore_index=ignored_index)
            start_loss = loss_fct(start_logits, start_positions)
            end_loss = loss_fct(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        return {"loss": total_loss,
                "start_logits": start_logits,
                "end_logits": end_logits,
                "hidden_states": distilbert_output.hidden_states,
                "attentions": distilbert_output.attentions}

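The nested children() indexing in __init__ above is hard to read. Assuming the standard DistilBertModel layout (an embeddings module followed by transformer.layer), the following illustrative sketch (not part of qa_model.py) reaches the same two blocks via named attributes:

# Illustrative sketch only, not part of qa_model.py. Assumes the standard DistilBertModel
# layout, i.e. base.transformer.layer is the ModuleList of TransformerBlocks.
import copy
from transformers import DistilBertForMaskedLM

base = DistilBertForMaskedLM.from_pretrained("distilbert-base-uncased").distilbert
reused_blocks = copy.deepcopy(base.transformer.layer[-2:])  # last two TransformerBlocks
assert len(reused_blocks) == 2
# the deep copies stay trainable, while every parameter of `base` is frozen afterwards
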
class Dataset(torch.utils.data.Dataset):
    """
    This class creates a dataset for the DistilBERT QA model.
    """
    def __init__(self, squad_paths, natural_question_paths, hotpotqa_paths, tokenizer):
        """
        Creates and initialises the dataset object
        """
        self.paths = []
        # counts examples whose answer span does not lie inside the (possibly truncated) context
        self.count = 0
        if squad_paths != None:
            self.paths.extend(squad_paths[:len(squad_paths)-1])
        if natural_question_paths != None:
            self.paths.extend(natural_question_paths[:len(natural_question_paths)-1])
        if hotpotqa_paths != None:
            self.paths.extend(hotpotqa_paths[:len(hotpotqa_paths)-1])
        self.data = None
        self.current_file = 0
        self.remaining = 0
        self.encodings = None
        # tokenizer for strings
        self.tokenizer = tokenizer

    def __len__(self):
        """
        returns the length of the dataset (each preprocessed file holds 1000 examples)
        """
        return len(self.paths)*1000

    def read_file(self, path):
        """
        reads the file stored at path
        """
        with open(path, 'r', encoding='utf-8') as f:
            lines = f.read().split('\n')
        return lines

    def get_encodings(self):
        """
        returns encoded strings for the model
        """
        # remove leading and trailing whitespace
        questions = [q.strip() for q in self.data["question"]]
        context = [q.lower() for q in self.data["context"]]

        # tokenise questions and context; if the context is too long, we truncate it
        inputs = self.tokenizer(
            questions,
            context,
            max_length=512,
            truncation="only_second",
            return_offsets_mapping=True,
            padding="max_length",
        )

        # tuples of character offsets giving us the original positions
        offset_mapping = inputs.pop("offset_mapping")

        answers = self.data["answer"]
        answer_start = self.data["answer_start"]

        # store beginning and end positions
        start_positions = []
        end_positions = []

        # iterate through questions
        for i, offset in enumerate(offset_mapping):

            answer = answers[i]
            start_char = int(answer_start[i])
            end_char = start_char + len(answer)

            sequence_ids = inputs.sequence_ids(i)

            # start and end of the context based on tokens
            idx = 0
            while sequence_ids[idx] != 1:
                idx += 1

            context_start = idx
            while sequence_ids[idx] == 1:
                idx += 1
            context_end = idx - 1

            # if the answer is not inside the context, add (0, 0)
            if offset[context_start][0] > end_char or offset[context_end][1] < start_char:
                start_positions.append(0)
                end_positions.append(0)
                self.count += 1
            else:
                # go to the first offset position that is smaller than start_char
                idx = context_start
                while idx <= context_end and offset[idx][0] <= start_char:
                    idx += 1

                start_positions.append(idx - 1)
                idx = context_end
                while idx >= context_start and offset[idx][1] >= end_char:
                    idx -= 1
                end_positions.append(idx + 1)

        # append start and end positions to the encodings
        inputs["start_positions"] = start_positions
        inputs["end_positions"] = end_positions
        # return input_ids, attention mask, start and end positions (ground truth)
        return {'input_ids': torch.tensor(inputs['input_ids']),
                'attention_mask': torch.tensor(inputs['attention_mask']),
                'start_positions': torch.tensor(inputs['start_positions']),
                'end_positions': torch.tensor(inputs['end_positions'])}

    def __getitem__(self, i):
        """
        returns the encoding of item i
        """

        # if we have looked at all items in the current file - load the next one
        if self.remaining == 0:
            self.data = self.read_file(self.paths[self.current_file])
            self.data = pd.DataFrame([line.split("\t") for line in self.data],
                                     columns=["context", "question", "answer", "answer_start"])
            self.current_file += 1
            self.remaining = len(self.data)
            self.encodings = self.get_encodings()
            # if we are at the end of the dataset, start over again
            if self.current_file == len(self.paths):
                self.current_file = 0
        self.remaining -= 1
        return {key: tensor[i % 1000] for key, tensor in self.encodings.items()}

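The core of get_encodings is mapping a character-level answer span onto token indices via the tokenizer's offset mapping. An illustrative, self-contained sketch of the same idea on a single hand-made example (not part of qa_model.py; the example strings are made up):

# Illustrative sketch only, not part of qa_model.py: character span -> token span via offsets.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("distilbert-base-uncased")
context = "the eiffel tower is located in paris."
question = "where is the eiffel tower located?"
answer = "paris"
start_char = context.index(answer)
end_char = start_char + len(answer)

enc = tok(question, context, return_offsets_mapping=True, truncation="only_second")
seq_ids = enc.sequence_ids()
offsets = enc["offset_mapping"]

# only tokens with sequence id 1 belong to the context
context_tokens = [i for i, s in enumerate(seq_ids) if s == 1]
start_token = next(i for i in context_tokens if offsets[i][0] <= start_char < offsets[i][1])
end_token = next(i for i in context_tokens if offsets[i][0] < end_char <= offsets[i][1])
print(start_token, end_token, tok.decode(enc["input_ids"][start_token:end_token + 1]))
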
def test_model(model, optim, test_ds_loader, device):
    """
    This function is used to sanity-check the model: parameters must be neither NaN nor infinite,
    non-frozen parameters have to change after one optimisation step, and frozen ones must not.
    :param model: pytorch model to evaluate
    :param optim: optimizer
    :param test_ds_loader: dataloader object
    :param device: device the model is on
    :raises Exception: if the model doesn't behave as expected
    """
    ## Check if non-frozen parameters changed and frozen ones did not

    # get parameters used for tuning and store their initial weights
    params = [np for np in model.named_parameters() if np[1].requires_grad]
    initial_params = [(name, p.clone()) for (name, p) in params]

    # get frozen parameters and store their initial weights
    params_frozen = [np for np in model.named_parameters() if not np[1].requires_grad]
    initial_params_frozen = [(name, p.clone()) for (name, p) in params_frozen]

    # perform one iteration
    optim.zero_grad()
    batch = next(iter(test_ds_loader))

    input_ids = batch['input_ids'].to(device)
    attention_mask = batch['attention_mask'].to(device)
    start_positions = batch['start_positions'].to(device)
    end_positions = batch['end_positions'].to(device)

    # forward pass and backpropagation
    outputs = model(input_ids, attention_mask=attention_mask, start_positions=start_positions,
                    end_positions=end_positions)
    loss = outputs['loss']
    loss.backward()
    optim.step()

    # check that the trainable variables have changed
    for (_, p0), (name, p1) in zip(initial_params, params):
        # check different than initial
        try:
            assert not torch.equal(p0.to(device), p1.to(device))
        except AssertionError:
            raise Exception(
                "{var_name} {msg}".format(
                    var_name=name,
                    msg='did not change!'
                )
            )
        # check not NaN
        try:
            assert not torch.isnan(p1).byte().any()
        except AssertionError:
            raise Exception(
                "{var_name} {msg}".format(
                    var_name=name,
                    msg='is NaN!'
                )
            )
        # check finite
        try:
            assert torch.isfinite(p1).byte().all()
        except AssertionError:
            raise Exception(
                "{var_name} {msg}".format(
                    var_name=name,
                    msg='is Inf!'
                )
            )

    # check that frozen weights have not changed
    for (_, p0), (name, p1) in zip(initial_params_frozen, params_frozen):
        # should be the same
        try:
            assert torch.equal(p0.to(device), p1.to(device))
        except AssertionError:
            raise Exception(
                "{var_name} {msg}".format(
                    var_name=name,
                    msg='changed!'
                )
            )
        # check not NaN
        try:
            assert not torch.isnan(p1).byte().any()
        except AssertionError:
            raise Exception(
                "{var_name} {msg}".format(
                    var_name=name,
                    msg='is NaN!'
                )
            )

        # check finite numbers
        try:
            assert torch.isfinite(p1).byte().all()
        except AssertionError:
            raise Exception(
                "{var_name} {msg}".format(
                    var_name=name,
                    msg='is Inf!'
                )
            )
    print("Passed")
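test_model performs a single optimisation step as a sanity check; training proper repeats the same step over the whole DataLoader. A minimal, illustrative training-loop sketch (not part of qa_model.py), assuming a model, optimiser, DataLoader and device as in test_model:

# Illustrative sketch only, not part of qa_model.py: one training epoch over a DataLoader,
# repeating the single step performed in test_model().
def train_one_epoch(model, optim, loader, device):
    model.train()
    total_loss = 0.0
    for batch in loader:
        optim.zero_grad()
        outputs = model(batch["input_ids"].to(device),
                        attention_mask=batch["attention_mask"].to(device),
                        start_positions=batch["start_positions"].to(device),
                        end_positions=batch["end_positions"].to(device))
        loss = outputs["loss"]
        loss.backward()
        optim.step()
        total_loss += loss.item()
    return total_loss / max(len(loader), 1)
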
question_answering.ipynb
ADDED
@@ -0,0 +1,2403 @@
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"id": "19817716",
|
6 |
+
"metadata": {},
|
7 |
+
"source": [
|
8 |
+
"# Question Answering\n",
|
9 |
+
"The following notebook contains different question answering models. We will start by introducing a representation for the dataset and corresponding DataLoader and then evaluate different models."
|
10 |
+
]
|
11 |
+
},
|
12 |
+
{
|
13 |
+
"cell_type": "code",
|
14 |
+
"execution_count": 50,
|
15 |
+
"id": "49bf46c6",
|
16 |
+
"metadata": {},
|
17 |
+
"outputs": [],
|
18 |
+
"source": [
|
19 |
+
"from transformers import DistilBertModel, DistilBertForMaskedLM, DistilBertConfig, \\\n",
|
20 |
+
" DistilBertTokenizerFast, AutoTokenizer, BertModel, BertForMaskedLM, BertTokenizerFast, BertConfig\n",
|
21 |
+
"from torch import nn\n",
|
22 |
+
"from pathlib import Path\n",
|
23 |
+
"import torch\n",
|
24 |
+
"import pandas as pd\n",
|
25 |
+
"from typing import Optional \n",
|
26 |
+
"from tqdm.auto import tqdm\n",
|
27 |
+
"from util import eval_test_set, count_parameters\n",
|
28 |
+
"from torch.optim import AdamW, RMSprop\n",
|
29 |
+
"\n",
|
30 |
+
"\n",
|
31 |
+
"from qa_model import QuestionDistilBERT, SimpleQuestionDistilBERT, ReuseQuestionDistilBERT, Dataset, test_model"
|
32 |
+
]
|
33 |
+
},
|
34 |
+
{
|
35 |
+
"cell_type": "markdown",
|
36 |
+
"id": "3ea47820",
|
37 |
+
"metadata": {},
|
38 |
+
"source": [
|
39 |
+
"## Data\n",
|
40 |
+
"Processing the data correctly is partly based on the Huggingface Tutorial (https://huggingface.co/course/chapter7/7?fw=pt)"
|
41 |
+
]
|
42 |
+
},
|
43 |
+
{
|
44 |
+
"cell_type": "code",
|
45 |
+
"execution_count": 51,
|
46 |
+
"id": "7b1b2b3e",
|
47 |
+
"metadata": {},
|
48 |
+
"outputs": [],
|
49 |
+
"source": [
|
50 |
+
"tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')"
|
51 |
+
]
|
52 |
+
},
|
53 |
+
{
|
54 |
+
"cell_type": "code",
|
55 |
+
"execution_count": 52,
|
56 |
+
"id": "f276eba7",
|
57 |
+
"metadata": {
|
58 |
+
"scrolled": false
|
59 |
+
},
|
60 |
+
"outputs": [],
|
61 |
+
"source": [
|
62 |
+
" \n",
|
63 |
+
"# create datasets and loaders for training and test set\n",
|
64 |
+
"squad_paths = [str(x) for x in Path('data/training_squad/').glob('**/*.txt')]\n",
|
65 |
+
"nat_paths = [str(x) for x in Path('data/natural_questions_train/').glob('**/*.txt')]\n",
|
66 |
+
"hotpotqa_paths = [str(x) for x in Path('data/hotpotqa_training/').glob('**/*.txt')]"
|
67 |
+
]
|
68 |
+
},
|
69 |
+
{
|
70 |
+
"cell_type": "markdown",
|
71 |
+
"id": "ad8d532a",
|
72 |
+
"metadata": {},
|
73 |
+
"source": [
|
74 |
+
"## POC Model\n",
|
75 |
+
"* Works very well:\n",
|
76 |
+
" * Dropout 0.1 is too small (overfitting after first epoch) - changed to 0.15\n",
|
77 |
+
" * Difference between AdamW and RMSprop minimal\n",
|
78 |
+
" \n",
|
79 |
+
"### Results:\n",
|
80 |
+
"Dropout = 0.15\n",
|
81 |
+
"* Mean EM: 0.5374\n",
|
82 |
+
"* Mean F-1: 0.6826317532406944\n",
|
83 |
+
"\n",
|
84 |
+
"Dropout = 0.2 (overfitting realtively similar to first, but seems to be too high)\n",
|
85 |
+
"* Mean EM: 0.5044\n",
|
86 |
+
"* Mean F-1: 0.6437359169276439"
|
87 |
+
]
|
88 |
+
},
|
89 |
+
{
|
90 |
+
"cell_type": "code",
|
91 |
+
"execution_count": 54,
|
92 |
+
"id": "703e7f38",
|
93 |
+
"metadata": {},
|
94 |
+
"outputs": [],
|
95 |
+
"source": [
|
96 |
+
"dataset = Dataset(squad_paths = squad_paths, natural_question_paths=None, hotpotqa_paths=hotpotqa_paths, tokenizer=tokenizer)\n",
|
97 |
+
"loader = torch.utils.data.DataLoader(dataset, batch_size=8)\n",
|
98 |
+
"\n",
|
99 |
+
"test_dataset = Dataset(squad_paths = [str(x) for x in Path('data/test_squad/').glob('**/*.txt')], \n",
|
100 |
+
" natural_question_paths=None, \n",
|
101 |
+
" hotpotqa_paths = None, tokenizer=tokenizer)\n",
|
102 |
+
"test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=4)"
|
103 |
+
]
|
104 |
+
},
|
105 |
+
{
|
106 |
+
"cell_type": "code",
|
107 |
+
"execution_count": 55,
|
108 |
+
"id": "6672f614",
|
109 |
+
"metadata": {},
|
110 |
+
"outputs": [],
|
111 |
+
"source": [
|
112 |
+
"model = DistilBertForMaskedLM.from_pretrained(\"distilbert-base-uncased\")\n",
|
113 |
+
"config = DistilBertConfig.from_pretrained(\"distilbert-base-uncased\")\n",
|
114 |
+
"mod = model.distilbert"
|
115 |
+
]
|
116 |
+
},
|
117 |
+
{
|
118 |
+
"cell_type": "code",
|
119 |
+
"execution_count": 56,
|
120 |
+
"id": "dec15198",
|
121 |
+
"metadata": {},
|
122 |
+
"outputs": [
|
123 |
+
{
|
124 |
+
"data": {
|
125 |
+
"text/plain": [
|
126 |
+
"SimpleQuestionDistilBERT(\n",
|
127 |
+
" (distilbert): DistilBertModel(\n",
|
128 |
+
" (embeddings): Embeddings(\n",
|
129 |
+
" (word_embeddings): Embedding(30522, 768, padding_idx=0)\n",
|
130 |
+
" (position_embeddings): Embedding(512, 768)\n",
|
131 |
+
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
|
132 |
+
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
133 |
+
" )\n",
|
134 |
+
" (transformer): Transformer(\n",
|
135 |
+
" (layer): ModuleList(\n",
|
136 |
+
" (0): TransformerBlock(\n",
|
137 |
+
" (attention): MultiHeadSelfAttention(\n",
|
138 |
+
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
139 |
+
" (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
140 |
+
" (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
141 |
+
" (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
142 |
+
" (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
143 |
+
" )\n",
|
144 |
+
" (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
|
145 |
+
" (ffn): FFN(\n",
|
146 |
+
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
147 |
+
" (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
|
148 |
+
" (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
|
149 |
+
" (activation): GELUActivation()\n",
|
150 |
+
" )\n",
|
151 |
+
" (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
|
152 |
+
" )\n",
|
153 |
+
" (1): TransformerBlock(\n",
|
154 |
+
" (attention): MultiHeadSelfAttention(\n",
|
155 |
+
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
156 |
+
" (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
157 |
+
" (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
158 |
+
" (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
159 |
+
" (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
160 |
+
" )\n",
|
161 |
+
" (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
|
162 |
+
" (ffn): FFN(\n",
|
163 |
+
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
164 |
+
" (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
|
165 |
+
" (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
|
166 |
+
" (activation): GELUActivation()\n",
|
167 |
+
" )\n",
|
168 |
+
" (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
|
169 |
+
" )\n",
|
170 |
+
" (2): TransformerBlock(\n",
|
171 |
+
" (attention): MultiHeadSelfAttention(\n",
|
172 |
+
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
173 |
+
" (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
174 |
+
" (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
175 |
+
" (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
176 |
+
" (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
177 |
+
" )\n",
|
178 |
+
" (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
|
179 |
+
" (ffn): FFN(\n",
|
180 |
+
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
181 |
+
" (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
|
182 |
+
" (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
|
183 |
+
" (activation): GELUActivation()\n",
|
184 |
+
" )\n",
|
185 |
+
" (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
|
186 |
+
" )\n",
|
187 |
+
" (3): TransformerBlock(\n",
|
188 |
+
" (attention): MultiHeadSelfAttention(\n",
|
189 |
+
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
190 |
+
" (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
191 |
+
" (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
192 |
+
" (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
193 |
+
" (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
194 |
+
" )\n",
|
195 |
+
" (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
|
196 |
+
" (ffn): FFN(\n",
|
197 |
+
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
198 |
+
" (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
|
199 |
+
" (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
|
200 |
+
" (activation): GELUActivation()\n",
|
201 |
+
" )\n",
|
202 |
+
" (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
|
203 |
+
" )\n",
|
204 |
+
" (4): TransformerBlock(\n",
|
205 |
+
" (attention): MultiHeadSelfAttention(\n",
|
206 |
+
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
207 |
+
" (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
208 |
+
" (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
209 |
+
" (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
210 |
+
" (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
211 |
+
" )\n",
|
212 |
+
" (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
|
213 |
+
" (ffn): FFN(\n",
|
214 |
+
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
215 |
+
" (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
|
216 |
+
" (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
|
217 |
+
" (activation): GELUActivation()\n",
|
218 |
+
" )\n",
|
219 |
+
" (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
|
220 |
+
" )\n",
|
221 |
+
" (5): TransformerBlock(\n",
|
222 |
+
" (attention): MultiHeadSelfAttention(\n",
|
223 |
+
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
224 |
+
" (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
225 |
+
" (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
226 |
+
" (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
227 |
+
" (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
228 |
+
" )\n",
|
229 |
+
" (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
|
230 |
+
" (ffn): FFN(\n",
|
231 |
+
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
232 |
+
" (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
|
233 |
+
" (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
|
234 |
+
" (activation): GELUActivation()\n",
|
235 |
+
" )\n",
|
236 |
+
" (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
|
237 |
+
" )\n",
|
238 |
+
" )\n",
|
239 |
+
" )\n",
|
240 |
+
" )\n",
|
241 |
+
" (dropout): Dropout(p=0.5, inplace=False)\n",
|
242 |
+
" (classifier): Linear(in_features=768, out_features=2, bias=True)\n",
|
243 |
+
")"
|
244 |
+
]
|
245 |
+
},
|
246 |
+
"execution_count": 56,
|
247 |
+
"metadata": {},
|
248 |
+
"output_type": "execute_result"
|
249 |
+
}
|
250 |
+
],
|
251 |
+
"source": [
|
252 |
+
"device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')\n",
|
253 |
+
"model = SimpleQuestionDistilBERT(mod)\n",
|
254 |
+
"model.to(device)"
|
255 |
+
]
|
256 |
+
},
|
257 |
+
{
|
258 |
+
"cell_type": "code",
|
259 |
+
"execution_count": 57,
|
260 |
+
"id": "9def3c83",
|
261 |
+
"metadata": {},
|
262 |
+
"outputs": [
|
263 |
+
{
|
264 |
+
"name": "stdout",
|
265 |
+
"output_type": "stream",
|
266 |
+
"text": [
|
267 |
+
"+---------------------------------------------------------+------------+\n",
|
268 |
+
"| Modules | Parameters |\n",
|
269 |
+
"+---------------------------------------------------------+------------+\n",
|
270 |
+
"| distilbert.embeddings.word_embeddings.weight | 23440896 |\n",
|
271 |
+
"| distilbert.embeddings.position_embeddings.weight | 393216 |\n",
|
272 |
+
"| distilbert.embeddings.LayerNorm.weight | 768 |\n",
|
273 |
+
"| distilbert.embeddings.LayerNorm.bias | 768 |\n",
|
274 |
+
"| distilbert.transformer.layer.0.attention.q_lin.weight | 589824 |\n",
|
275 |
+
"| distilbert.transformer.layer.0.attention.q_lin.bias | 768 |\n",
|
276 |
+
"| distilbert.transformer.layer.0.attention.k_lin.weight | 589824 |\n",
|
277 |
+
"| distilbert.transformer.layer.0.attention.k_lin.bias | 768 |\n",
|
278 |
+
"| distilbert.transformer.layer.0.attention.v_lin.weight | 589824 |\n",
|
279 |
+
"| distilbert.transformer.layer.0.attention.v_lin.bias | 768 |\n",
|
280 |
+
"| distilbert.transformer.layer.0.attention.out_lin.weight | 589824 |\n",
|
281 |
+
"| distilbert.transformer.layer.0.attention.out_lin.bias | 768 |\n",
|
282 |
+
"| distilbert.transformer.layer.0.sa_layer_norm.weight | 768 |\n",
|
283 |
+
"| distilbert.transformer.layer.0.sa_layer_norm.bias | 768 |\n",
|
284 |
+
"| distilbert.transformer.layer.0.ffn.lin1.weight | 2359296 |\n",
|
285 |
+
"| distilbert.transformer.layer.0.ffn.lin1.bias | 3072 |\n",
|
286 |
+
"| distilbert.transformer.layer.0.ffn.lin2.weight | 2359296 |\n",
|
287 |
+
"| distilbert.transformer.layer.0.ffn.lin2.bias | 768 |\n",
|
288 |
+
"| distilbert.transformer.layer.0.output_layer_norm.weight | 768 |\n",
|
289 |
+
"| distilbert.transformer.layer.0.output_layer_norm.bias | 768 |\n",
|
290 |
+
"| distilbert.transformer.layer.1.attention.q_lin.weight | 589824 |\n",
|
291 |
+
"| distilbert.transformer.layer.1.attention.q_lin.bias | 768 |\n",
|
292 |
+
"| distilbert.transformer.layer.1.attention.k_lin.weight | 589824 |\n",
|
293 |
+
"| distilbert.transformer.layer.1.attention.k_lin.bias | 768 |\n",
|
294 |
+
"| distilbert.transformer.layer.1.attention.v_lin.weight | 589824 |\n",
|
295 |
+
"| distilbert.transformer.layer.1.attention.v_lin.bias | 768 |\n",
|
296 |
+
"| distilbert.transformer.layer.1.attention.out_lin.weight | 589824 |\n",
|
297 |
+
"| distilbert.transformer.layer.1.attention.out_lin.bias | 768 |\n",
|
298 |
+
"| distilbert.transformer.layer.1.sa_layer_norm.weight | 768 |\n",
|
299 |
+
"| distilbert.transformer.layer.1.sa_layer_norm.bias | 768 |\n",
|
300 |
+
"| distilbert.transformer.layer.1.ffn.lin1.weight | 2359296 |\n",
|
301 |
+
"| distilbert.transformer.layer.1.ffn.lin1.bias | 3072 |\n",
|
302 |
+
"| distilbert.transformer.layer.1.ffn.lin2.weight | 2359296 |\n",
|
303 |
+
"| distilbert.transformer.layer.1.ffn.lin2.bias | 768 |\n",
|
304 |
+
"| distilbert.transformer.layer.1.output_layer_norm.weight | 768 |\n",
|
305 |
+
"| distilbert.transformer.layer.1.output_layer_norm.bias | 768 |\n",
|
306 |
+
"| distilbert.transformer.layer.2.attention.q_lin.weight | 589824 |\n",
|
307 |
+
"| distilbert.transformer.layer.2.attention.q_lin.bias | 768 |\n",
|
308 |
+
"| distilbert.transformer.layer.2.attention.k_lin.weight | 589824 |\n",
|
309 |
+
"| distilbert.transformer.layer.2.attention.k_lin.bias | 768 |\n",
|
310 |
+
"| distilbert.transformer.layer.2.attention.v_lin.weight | 589824 |\n",
|
311 |
+
"| distilbert.transformer.layer.2.attention.v_lin.bias | 768 |\n",
|
312 |
+
"| distilbert.transformer.layer.2.attention.out_lin.weight | 589824 |\n",
|
313 |
+
"| distilbert.transformer.layer.2.attention.out_lin.bias | 768 |\n",
|
314 |
+
"| distilbert.transformer.layer.2.sa_layer_norm.weight | 768 |\n",
|
315 |
+
"| distilbert.transformer.layer.2.sa_layer_norm.bias | 768 |\n",
|
316 |
+
"| distilbert.transformer.layer.2.ffn.lin1.weight | 2359296 |\n",
|
317 |
+
"| distilbert.transformer.layer.2.ffn.lin1.bias | 3072 |\n",
|
318 |
+
"| distilbert.transformer.layer.2.ffn.lin2.weight | 2359296 |\n",
|
319 |
+
"| distilbert.transformer.layer.2.ffn.lin2.bias | 768 |\n",
|
320 |
+
"| distilbert.transformer.layer.2.output_layer_norm.weight | 768 |\n",
|
321 |
+
"| distilbert.transformer.layer.2.output_layer_norm.bias | 768 |\n",
|
322 |
+
"| distilbert.transformer.layer.3.attention.q_lin.weight | 589824 |\n",
|
323 |
+
"| distilbert.transformer.layer.3.attention.q_lin.bias | 768 |\n",
|
324 |
+
"| distilbert.transformer.layer.3.attention.k_lin.weight | 589824 |\n",
|
325 |
+
"| distilbert.transformer.layer.3.attention.k_lin.bias | 768 |\n",
|
326 |
+
"| distilbert.transformer.layer.3.attention.v_lin.weight | 589824 |\n",
|
327 |
+
"| distilbert.transformer.layer.3.attention.v_lin.bias | 768 |\n",
|
328 |
+
"| distilbert.transformer.layer.3.attention.out_lin.weight | 589824 |\n",
|
329 |
+
"| distilbert.transformer.layer.3.attention.out_lin.bias | 768 |\n",
|
330 |
+
"| distilbert.transformer.layer.3.sa_layer_norm.weight | 768 |\n",
|
331 |
+
"| distilbert.transformer.layer.3.sa_layer_norm.bias | 768 |\n",
|
332 |
+
"| distilbert.transformer.layer.3.ffn.lin1.weight | 2359296 |\n",
|
333 |
+
"| distilbert.transformer.layer.3.ffn.lin1.bias | 3072 |\n",
|
334 |
+
"| distilbert.transformer.layer.3.ffn.lin2.weight | 2359296 |\n",
|
335 |
+
"| distilbert.transformer.layer.3.ffn.lin2.bias | 768 |\n",
|
336 |
+
"| distilbert.transformer.layer.3.output_layer_norm.weight | 768 |\n",
|
337 |
+
"| distilbert.transformer.layer.3.output_layer_norm.bias | 768 |\n",
|
338 |
+
"| distilbert.transformer.layer.4.attention.q_lin.weight | 589824 |\n",
|
339 |
+
"| distilbert.transformer.layer.4.attention.q_lin.bias | 768 |\n",
|
340 |
+
"| distilbert.transformer.layer.4.attention.k_lin.weight | 589824 |\n",
|
341 |
+
"| distilbert.transformer.layer.4.attention.k_lin.bias | 768 |\n",
|
342 |
+
"| distilbert.transformer.layer.4.attention.v_lin.weight | 589824 |\n",
|
343 |
+
"| distilbert.transformer.layer.4.attention.v_lin.bias | 768 |\n",
|
344 |
+
"| distilbert.transformer.layer.4.attention.out_lin.weight | 589824 |\n",
|
345 |
+
"| distilbert.transformer.layer.4.attention.out_lin.bias | 768 |\n",
|
346 |
+
"| distilbert.transformer.layer.4.sa_layer_norm.weight | 768 |\n",
|
347 |
+
"| distilbert.transformer.layer.4.sa_layer_norm.bias | 768 |\n",
|
348 |
+
"| distilbert.transformer.layer.4.ffn.lin1.weight | 2359296 |\n",
|
349 |
+
"| distilbert.transformer.layer.4.ffn.lin1.bias | 3072 |\n",
|
350 |
+
"| distilbert.transformer.layer.4.ffn.lin2.weight | 2359296 |\n",
|
351 |
+
"| distilbert.transformer.layer.4.ffn.lin2.bias | 768 |\n",
|
352 |
+
"| distilbert.transformer.layer.4.output_layer_norm.weight | 768 |\n",
|
353 |
+
"| distilbert.transformer.layer.4.output_layer_norm.bias | 768 |\n",
|
354 |
+
"| distilbert.transformer.layer.5.attention.q_lin.weight | 589824 |\n",
|
355 |
+
"| distilbert.transformer.layer.5.attention.q_lin.bias | 768 |\n",
|
356 |
+
"| distilbert.transformer.layer.5.attention.k_lin.weight | 589824 |\n",
|
357 |
+
"| distilbert.transformer.layer.5.attention.k_lin.bias | 768 |\n",
|
358 |
+
"| distilbert.transformer.layer.5.attention.v_lin.weight | 589824 |\n",
|
359 |
+
"| distilbert.transformer.layer.5.attention.v_lin.bias | 768 |\n",
|
360 |
+
"| distilbert.transformer.layer.5.attention.out_lin.weight | 589824 |\n",
|
361 |
+
"| distilbert.transformer.layer.5.attention.out_lin.bias | 768 |\n",
|
362 |
+
"| distilbert.transformer.layer.5.sa_layer_norm.weight | 768 |\n",
|
363 |
+
"| distilbert.transformer.layer.5.sa_layer_norm.bias | 768 |\n",
|
364 |
+
"| distilbert.transformer.layer.5.ffn.lin1.weight | 2359296 |\n",
|
365 |
+
"| distilbert.transformer.layer.5.ffn.lin1.bias | 3072 |\n",
|
366 |
+
"| distilbert.transformer.layer.5.ffn.lin2.weight | 2359296 |\n",
|
367 |
+
"| distilbert.transformer.layer.5.ffn.lin2.bias | 768 |\n",
|
368 |
+
"| distilbert.transformer.layer.5.output_layer_norm.weight | 768 |\n",
|
369 |
+
"| distilbert.transformer.layer.5.output_layer_norm.bias | 768 |\n",
|
370 |
+
"| classifier.weight | 1536 |\n",
|
371 |
+
"| classifier.bias | 2 |\n",
|
372 |
+
"+---------------------------------------------------------+------------+\n",
|
373 |
+
"Total Trainable Params: 66364418\n"
|
374 |
+
]
|
375 |
+
},
|
376 |
+
{
|
377 |
+
"data": {
|
378 |
+
"text/plain": [
|
379 |
+
"66364418"
|
380 |
+
]
|
381 |
+
},
|
382 |
+
"execution_count": 57,
|
383 |
+
"metadata": {},
|
384 |
+
"output_type": "execute_result"
|
385 |
+
}
|
386 |
+
],
|
387 |
+
"source": [
|
388 |
+
"count_parameters(model)"
|
389 |
+
]
|
390 |
+
},
|
391 |
+
{
|
392 |
+
"cell_type": "markdown",
|
393 |
+
"id": "426a6311",
|
394 |
+
"metadata": {},
|
395 |
+
"source": [
|
396 |
+
"### Testing the model"
|
397 |
+
]
|
398 |
+
},
|
399 |
+
{
|
400 |
+
"cell_type": "code",
|
401 |
+
"execution_count": 58,
|
402 |
+
"id": "6151c201",
|
403 |
+
"metadata": {},
|
404 |
+
"outputs": [],
|
405 |
+
"source": [
|
406 |
+
"# get smaller dataset\n",
|
407 |
+
"batch_size = 8\n",
|
408 |
+
"test_ds = Dataset(squad_paths = squad_paths[:2], natural_question_paths=None, hotpotqa_paths=None, tokenizer=tokenizer)\n",
|
409 |
+
"test_ds_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size)\n",
|
410 |
+
"optim = RMSprop(model.parameters(), lr=1e-4)"
|
411 |
+
]
|
412 |
+
},
|
413 |
+
{
|
414 |
+
"cell_type": "code",
|
415 |
+
"execution_count": 59,
|
416 |
+
"id": "aeae0c56",
|
417 |
+
"metadata": {},
|
418 |
+
"outputs": [
|
419 |
+
{
|
420 |
+
"name": "stdout",
|
421 |
+
"output_type": "stream",
|
422 |
+
"text": [
|
423 |
+
"Passed\n"
|
424 |
+
]
|
425 |
+
}
|
426 |
+
],
|
427 |
+
"source": [
|
428 |
+
"test_model(model, optim, test_ds_loader, device)"
|
429 |
+
]
|
430 |
+
},
|
431 |
+
{
|
432 |
+
"cell_type": "markdown",
|
433 |
+
"id": "59928d34",
|
434 |
+
"metadata": {},
|
435 |
+
"source": [
|
436 |
+
"### Model Training"
|
437 |
+
]
|
438 |
+
},
|
439 |
+
{
|
440 |
+
"cell_type": "code",
|
441 |
+
"execution_count": 60,
|
442 |
+
"id": "a8017b8c",
|
443 |
+
"metadata": {},
|
444 |
+
"outputs": [
|
445 |
+
{
|
446 |
+
"data": {
|
447 |
+
"text/plain": [
|
448 |
+
"SimpleQuestionDistilBERT(\n",
|
449 |
+
" (distilbert): DistilBertModel(\n",
|
450 |
+
" (embeddings): Embeddings(\n",
|
451 |
+
" (word_embeddings): Embedding(30522, 768, padding_idx=0)\n",
|
452 |
+
" (position_embeddings): Embedding(512, 768)\n",
|
453 |
+
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
|
454 |
+
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
455 |
+
" )\n",
|
456 |
+
" (transformer): Transformer(\n",
|
457 |
+
" (layer): ModuleList(\n",
|
458 |
+
" (0): TransformerBlock(\n",
|
459 |
+
" (attention): MultiHeadSelfAttention(\n",
|
460 |
+
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
461 |
+
" (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
462 |
+
" (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
463 |
+
" (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
464 |
+
" (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
465 |
+
" )\n",
|
466 |
+
" (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
|
467 |
+
" (ffn): FFN(\n",
|
468 |
+
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
469 |
+
" (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
|
470 |
+
" (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
|
471 |
+
" (activation): GELUActivation()\n",
|
472 |
+
" )\n",
|
473 |
+
" (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
|
474 |
+
" )\n",
|
475 |
+
" (1): TransformerBlock(\n",
|
476 |
+
" (attention): MultiHeadSelfAttention(\n",
|
477 |
+
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
478 |
+
" (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
479 |
+
" (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
480 |
+
" (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
481 |
+
" (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
482 |
+
" )\n",
|
483 |
+
" (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
|
484 |
+
" (ffn): FFN(\n",
|
485 |
+
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
486 |
+
" (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
|
487 |
+
" (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
|
488 |
+
" (activation): GELUActivation()\n",
|
489 |
+
" )\n",
|
490 |
+
" (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
|
491 |
+
" )\n",
|
492 |
+
" (2): TransformerBlock(\n",
|
493 |
+
" (attention): MultiHeadSelfAttention(\n",
|
494 |
+
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
495 |
+
" (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
496 |
+
" (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
497 |
+
" (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
498 |
+
" (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
499 |
+
" )\n",
|
500 |
+
" (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
|
501 |
+
" (ffn): FFN(\n",
|
502 |
+
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
503 |
+
" (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
|
504 |
+
" (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
|
505 |
+
" (activation): GELUActivation()\n",
|
506 |
+
" )\n",
|
507 |
+
" (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
|
508 |
+
" )\n",
|
509 |
+
" (3): TransformerBlock(\n",
|
510 |
+
" (attention): MultiHeadSelfAttention(\n",
|
511 |
+
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
512 |
+
" (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
513 |
+
" (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
514 |
+
" (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
515 |
+
" (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
516 |
+
" )\n",
|
517 |
+
" (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
|
518 |
+
" (ffn): FFN(\n",
|
519 |
+
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
520 |
+
" (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
|
521 |
+
" (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
|
522 |
+
" (activation): GELUActivation()\n",
|
523 |
+
" )\n",
|
524 |
+
" (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
|
525 |
+
" )\n",
|
526 |
+
" (4): TransformerBlock(\n",
|
527 |
+
" (attention): MultiHeadSelfAttention(\n",
|
528 |
+
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
529 |
+
" (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
530 |
+
" (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
531 |
+
" (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
532 |
+
" (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
533 |
+
" )\n",
|
534 |
+
" (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
|
535 |
+
" (ffn): FFN(\n",
|
536 |
+
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
537 |
+
" (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
|
538 |
+
" (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
|
539 |
+
" (activation): GELUActivation()\n",
|
540 |
+
" )\n",
|
541 |
+
" (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
|
542 |
+
" )\n",
|
543 |
+
" (5): TransformerBlock(\n",
|
544 |
+
" (attention): MultiHeadSelfAttention(\n",
|
545 |
+
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
546 |
+
" (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
547 |
+
" (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
548 |
+
" (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
549 |
+
" (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
550 |
+
" )\n",
|
551 |
+
" (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
|
552 |
+
" (ffn): FFN(\n",
|
553 |
+
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
554 |
+
" (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
|
555 |
+
" (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
|
556 |
+
" (activation): GELUActivation()\n",
|
557 |
+
" )\n",
|
558 |
+
" (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
|
559 |
+
" )\n",
|
560 |
+
" )\n",
|
561 |
+
" )\n",
|
562 |
+
" )\n",
|
563 |
+
" (dropout): Dropout(p=0.5, inplace=False)\n",
|
564 |
+
" (classifier): Linear(in_features=768, out_features=2, bias=True)\n",
|
565 |
+
")"
|
566 |
+
]
|
567 |
+
},
|
568 |
+
"execution_count": 60,
|
569 |
+
"metadata": {},
|
570 |
+
"output_type": "execute_result"
|
571 |
+
}
|
572 |
+
],
|
573 |
+
"source": [
|
574 |
+
"device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')\n",
|
575 |
+
"model = SimpleQuestionDistilBERT(mod)\n",
|
576 |
+
"model.to(device)"
|
577 |
+
]
|
578 |
+
},
|
579 |
+
{
|
580 |
+
"cell_type": "code",
|
581 |
+
"execution_count": 61,
|
582 |
+
"id": "f13c12dc",
|
583 |
+
"metadata": {},
|
584 |
+
"outputs": [],
|
585 |
+
"source": [
|
586 |
+
"model.train()\n",
|
587 |
+
"optim = RMSprop(model.parameters(), lr=1e-4)"
|
588 |
+
]
|
589 |
+
},
|
590 |
+
{
|
591 |
+
"cell_type": "code",
|
592 |
+
"execution_count": 22,
|
593 |
+
"id": "e4fa54d9",
|
594 |
+
"metadata": {},
|
595 |
+
"outputs": [
|
596 |
+
{
|
597 |
+
"data": {
|
598 |
+
"application/vnd.jupyter.widget-view+json": {
|
599 |
+
"model_id": "0016d9f5ba764eb98e9df8573995c86c",
|
600 |
+
"version_major": 2,
|
601 |
+
"version_minor": 0
|
602 |
+
},
|
603 |
+
"text/plain": [
|
604 |
+
" 0%| | 0/10875 [00:00<?, ?it/s]"
|
605 |
+
]
|
606 |
+
},
|
607 |
+
"metadata": {},
|
608 |
+
"output_type": "display_data"
|
609 |
+
},
|
610 |
+
{
|
611 |
+
"name": "stdout",
|
612 |
+
"output_type": "stream",
|
613 |
+
"text": [
|
614 |
+
"Mean Training Error 0.7555404769408292\n"
|
615 |
+
]
|
616 |
+
},
|
617 |
+
{
|
618 |
+
"data": {
|
619 |
+
"application/vnd.jupyter.widget-view+json": {
|
620 |
+
"model_id": "96af0e22e2ee44fd920795b0e7317839",
|
621 |
+
"version_major": 2,
|
622 |
+
"version_minor": 0
|
623 |
+
},
|
624 |
+
"text/plain": [
|
625 |
+
" 0%| | 0/2500 [00:00<?, ?it/s]"
|
626 |
+
]
|
627 |
+
},
|
628 |
+
"metadata": {},
|
629 |
+
"output_type": "display_data"
|
630 |
+
},
|
631 |
+
{
|
632 |
+
"name": "stdout",
|
633 |
+
"output_type": "stream",
|
634 |
+
"text": [
|
635 |
+
"Mean Test Error 1.761920437876694\n"
|
636 |
+
]
|
637 |
+
},
|
638 |
+
{
|
639 |
+
"data": {
|
640 |
+
"application/vnd.jupyter.widget-view+json": {
|
641 |
+
"model_id": "5160ffe5f60e4b72b46746a33b1d60d0",
|
642 |
+
"version_major": 2,
|
643 |
+
"version_minor": 0
|
644 |
+
},
|
645 |
+
"text/plain": [
|
646 |
+
" 0%| | 0/10875 [00:00<?, ?it/s]"
|
647 |
+
]
|
648 |
+
},
|
649 |
+
"metadata": {},
|
650 |
+
"output_type": "display_data"
|
651 |
+
},
|
652 |
+
{
|
653 |
+
"ename": "KeyboardInterrupt",
|
654 |
+
"evalue": "",
|
655 |
+
"output_type": "error",
|
656 |
+
"traceback": [
|
657 |
+
"\u001B[0;31m---------------------------------------------------------------------------\u001B[0m",
|
658 |
+
"\u001B[0;31mKeyboardInterrupt\u001B[0m Traceback (most recent call last)",
|
659 |
+
"Cell \u001B[0;32mIn [22], line 18\u001B[0m\n\u001B[1;32m 16\u001B[0m \u001B[38;5;66;03m# print(torch.argmax(outputs['start_logits'],axis=1), torch.argmax(outputs['end_logits'], axis=1), start, end)\u001B[39;00m\n\u001B[1;32m 17\u001B[0m loss \u001B[38;5;241m=\u001B[39m outputs[\u001B[38;5;124m'\u001B[39m\u001B[38;5;124mloss\u001B[39m\u001B[38;5;124m'\u001B[39m]\n\u001B[0;32m---> 18\u001B[0m loss\u001B[38;5;241m.\u001B[39mbackward()\n\u001B[1;32m 19\u001B[0m \u001B[38;5;66;03m# torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)\u001B[39;00m\n\u001B[1;32m 20\u001B[0m optim\u001B[38;5;241m.\u001B[39mstep()\n",
|
660 |
+
"File \u001B[0;32m~/Documents/University/WS2022/applieddl/venv/lib64/python3.10/site-packages/torch/_tensor.py:396\u001B[0m, in \u001B[0;36mTensor.backward\u001B[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001B[0m\n\u001B[1;32m 387\u001B[0m \u001B[38;5;28;01mif\u001B[39;00m has_torch_function_unary(\u001B[38;5;28mself\u001B[39m):\n\u001B[1;32m 388\u001B[0m \u001B[38;5;28;01mreturn\u001B[39;00m handle_torch_function(\n\u001B[1;32m 389\u001B[0m Tensor\u001B[38;5;241m.\u001B[39mbackward,\n\u001B[1;32m 390\u001B[0m (\u001B[38;5;28mself\u001B[39m,),\n\u001B[0;32m (...)\u001B[0m\n\u001B[1;32m 394\u001B[0m create_graph\u001B[38;5;241m=\u001B[39mcreate_graph,\n\u001B[1;32m 395\u001B[0m inputs\u001B[38;5;241m=\u001B[39minputs)\n\u001B[0;32m--> 396\u001B[0m \u001B[43mtorch\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mautograd\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mbackward\u001B[49m\u001B[43m(\u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mgradient\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mretain_graph\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mcreate_graph\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43minputs\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[43minputs\u001B[49m\u001B[43m)\u001B[49m\n",
|
661 |
+
"File \u001B[0;32m~/Documents/University/WS2022/applieddl/venv/lib64/python3.10/site-packages/torch/autograd/__init__.py:173\u001B[0m, in \u001B[0;36mbackward\u001B[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001B[0m\n\u001B[1;32m 168\u001B[0m retain_graph \u001B[38;5;241m=\u001B[39m create_graph\n\u001B[1;32m 170\u001B[0m \u001B[38;5;66;03m# The reason we repeat same the comment below is that\u001B[39;00m\n\u001B[1;32m 171\u001B[0m \u001B[38;5;66;03m# some Python versions print out the first line of a multi-line function\u001B[39;00m\n\u001B[1;32m 172\u001B[0m \u001B[38;5;66;03m# calls in the traceback and some print out the last line\u001B[39;00m\n\u001B[0;32m--> 173\u001B[0m \u001B[43mVariable\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43m_execution_engine\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mrun_backward\u001B[49m\u001B[43m(\u001B[49m\u001B[43m \u001B[49m\u001B[38;5;66;43;03m# Calls into the C++ engine to run the backward pass\u001B[39;49;00m\n\u001B[1;32m 174\u001B[0m \u001B[43m \u001B[49m\u001B[43mtensors\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mgrad_tensors_\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mretain_graph\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mcreate_graph\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43minputs\u001B[49m\u001B[43m,\u001B[49m\n\u001B[1;32m 175\u001B[0m \u001B[43m \u001B[49m\u001B[43mallow_unreachable\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;28;43;01mTrue\u001B[39;49;00m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43maccumulate_grad\u001B[49m\u001B[38;5;241;43m=\u001B[39;49m\u001B[38;5;28;43;01mTrue\u001B[39;49;00m\u001B[43m)\u001B[49m\n",
|
662 |
+
"\u001B[0;31mKeyboardInterrupt\u001B[0m: "
|
663 |
+
]
|
664 |
+
}
|
665 |
+
],
|
666 |
+
"source": [
|
667 |
+
"epochs = 5\n",
|
668 |
+
"\n",
|
669 |
+
"for epoch in range(epochs):\n",
|
670 |
+
" loop = tqdm(loader, leave=True)\n",
|
671 |
+
" model.train()\n",
|
672 |
+
" mean_training_error = []\n",
|
673 |
+
" for batch in loop:\n",
|
674 |
+
" optim.zero_grad()\n",
|
675 |
+
" \n",
|
676 |
+
" input_ids = batch['input_ids'].to(device)\n",
|
677 |
+
" attention_mask = batch['attention_mask'].to(device)\n",
|
678 |
+
" start = batch['start_positions'].to(device)\n",
|
679 |
+
" end = batch['end_positions'].to(device)\n",
|
680 |
+
" \n",
|
681 |
+
" outputs = model(input_ids, attention_mask=attention_mask, start_positions=start, end_positions=end)\n",
|
682 |
+
" # print(torch.argmax(outputs['start_logits'],axis=1), torch.argmax(outputs['end_logits'], axis=1), start, end)\n",
|
683 |
+
" loss = outputs['loss']\n",
|
684 |
+
" loss.backward()\n",
|
685 |
+
" # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)\n",
|
686 |
+
" optim.step()\n",
|
687 |
+
" mean_training_error.append(loss.item())\n",
|
688 |
+
" loop.set_description(f'Epoch {epoch}')\n",
|
689 |
+
" loop.set_postfix(loss=loss.item())\n",
|
690 |
+
" print(\"Mean Training Error\", np.mean(mean_training_error))\n",
|
691 |
+
" \n",
|
692 |
+
" \n",
|
693 |
+
" loop = tqdm(test_loader, leave=True)\n",
|
694 |
+
" model.eval()\n",
|
695 |
+
" mean_test_error = []\n",
|
696 |
+
" for batch in loop:\n",
|
697 |
+
" \n",
|
698 |
+
" input_ids = batch['input_ids'].to(device)\n",
|
699 |
+
" attention_mask = batch['attention_mask'].to(device)\n",
|
700 |
+
" start = batch['start_positions'].to(device)\n",
|
701 |
+
" end = batch['end_positions'].to(device)\n",
|
702 |
+
" \n",
|
703 |
+
" outputs = model(input_ids, attention_mask=attention_mask, start_positions=start, end_positions=end)\n",
|
704 |
+
" # print(torch.argmax(outputs['start_logits'],axis=1), torch.argmax(outputs['end_logits'], axis=1), start, end)\n",
|
705 |
+
" loss = outputs['loss']\n",
|
706 |
+
" \n",
|
707 |
+
" mean_test_error.append(loss.item())\n",
|
708 |
+
" loop.set_description(f'Epoch {epoch} Testset')\n",
|
709 |
+
" loop.set_postfix(loss=loss.item())\n",
|
710 |
+
" print(\"Mean Test Error\", np.mean(mean_test_error))"
|
711 |
+
]
|
712 |
+
},
|
713 |
+
{
|
714 |
+
"cell_type": "code",
|
715 |
+
"execution_count": 19,
|
716 |
+
"id": "6ff26fb4",
|
717 |
+
"metadata": {},
|
718 |
+
"outputs": [],
|
719 |
+
"source": [
|
720 |
+
"torch.save(model.state_dict(), \"simple_distilbert_qa.model\")"
|
721 |
+
]
|
722 |
+
},
|
723 |
+
{
|
724 |
+
"cell_type": "code",
|
725 |
+
"execution_count": 20,
|
726 |
+
"id": "a5e7abeb",
|
727 |
+
"metadata": {},
|
728 |
+
"outputs": [
|
729 |
+
{
|
730 |
+
"data": {
|
731 |
+
"text/plain": [
|
732 |
+
"<All keys matched successfully>"
|
733 |
+
]
|
734 |
+
},
|
735 |
+
"execution_count": 20,
|
736 |
+
"metadata": {},
|
737 |
+
"output_type": "execute_result"
|
738 |
+
}
|
739 |
+
],
|
740 |
+
"source": [
|
741 |
+
"model = SimpleQuestionDistilBERT(mod)\n",
|
742 |
+
"model.load_state_dict(torch.load(\"simple_distilbert_qa.model\"))"
|
743 |
+
]
|
744 |
+
},
|
745 |
+
{
|
746 |
+
"cell_type": "code",
|
747 |
+
"execution_count": 18,
|
748 |
+
"id": "f5ad7bee",
|
749 |
+
"metadata": {},
|
750 |
+
"outputs": [
|
751 |
+
{
|
752 |
+
"name": "stderr",
|
753 |
+
"output_type": "stream",
|
754 |
+
"text": [
|
755 |
+
"100%|鈻堚枅鈻堚枅鈻堚枅鈻堚枅鈻堚枅| 2500/2500 [02:09<00:00, 19.37it/s]"
|
756 |
+
]
|
757 |
+
},
|
758 |
+
{
|
759 |
+
"name": "stdout",
|
760 |
+
"output_type": "stream",
|
761 |
+
"text": [
|
762 |
+
"Mean EM: 0.5374\n",
|
763 |
+
"Mean F-1: 0.6826317532406944\n"
|
764 |
+
]
|
765 |
+
},
|
766 |
+
{
|
767 |
+
"name": "stderr",
|
768 |
+
"output_type": "stream",
|
769 |
+
"text": [
|
770 |
+
"\n"
|
771 |
+
]
|
772 |
+
}
|
773 |
+
],
|
774 |
+
"source": [
|
775 |
+
"eval_test_set(model, tokenizer, test_loader, device)"
|
776 |
+
]
|
777 |
+
},
|
778 |
+
{
|
779 |
+
"cell_type": "markdown",
|
780 |
+
"id": "fa6017a8",
|
781 |
+
"metadata": {},
|
782 |
+
"source": [
|
783 |
+
"## Freeze baseline and train new head\n",
|
784 |
+
"This was my initial idea, to freeze the layers and add a completely new head, which we train from scratch. I tried a lot of different configurations, but nothing really worked, I usually stayed at a CrossEntropyLoss of about 3 the whole time. Below, you can see the different heads I have tried.\n",
|
785 |
+
"\n",
|
786 |
+
"Furthermore, I experimented with different data, because I though it might not be enough data all in all. I would conclude that this didn't work because (1) Transformers are very data-hungry and I probably still used too little data (one epoch took about 1h though, so it wasn't possible to use even more). (2) We train the layers completely new, which means they contain absolutely no structure about the problem and task beforehand. I do not think that this way of training leads to better results / less energy used all in all, because it would be too resource intense.\n",
|
787 |
+
"\n",
|
788 |
+
"The following setup is partly based on the HuggingFace implementation of the question answering model (https://github.com/huggingface/transformers/blob/v4.23.1/src/transformers/models/distilbert/modeling_distilbert.py#L805)"
|
789 |
+
]
|
790 |
+
},
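To make the freeze-and-new-head setup concrete, here is a minimal sketch under the assumptions stated above: the pretrained DistilBERT base is completely frozen and a freshly initialised head maps each token representation to start/end logits, trained with a CrossEntropyLoss. The class name FrozenBaseQAHead and the exact head layout are illustrative only; they are not the QuestionDistilBERT defined in qa_model.py.

import torch
import torch.nn as nn

class FrozenBaseQAHead(nn.Module):
    # illustrative head-on-frozen-base model, not the notebook's actual class
    def __init__(self, distilbert_base, hidden_size=768):
        super().__init__()
        self.distilbert = distilbert_base
        for p in self.distilbert.parameters():
            p.requires_grad = False  # base stays frozen, only the head is trained
        self.head = nn.Sequential(
            nn.Dropout(0.1),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 2),  # start/end logits per token
        )
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, input_ids, attention_mask=None, start_positions=None, end_positions=None):
        hidden = self.distilbert(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        logits = self.head(hidden)                        # (batch, seq_len, 2)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)
        loss = None
        if start_positions is not None and end_positions is not None:
            loss = (self.loss_fn(start_logits, start_positions)
                    + self.loss_fn(end_logits, end_positions)) / 2
        return {'loss': loss, 'start_logits': start_logits, 'end_logits': end_logits}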
|
791 |
+
{
|
792 |
+
"cell_type": "code",
|
793 |
+
"execution_count": 62,
|
794 |
+
"id": "92b21967",
|
795 |
+
"metadata": {},
|
796 |
+
"outputs": [],
|
797 |
+
"source": [
|
798 |
+
"model = DistilBertForMaskedLM.from_pretrained(\"distilbert-base-uncased\")"
|
799 |
+
]
|
800 |
+
},
|
801 |
+
{
|
802 |
+
"cell_type": "code",
|
803 |
+
"execution_count": 63,
|
804 |
+
"id": "1d7b3a8c",
|
805 |
+
"metadata": {},
|
806 |
+
"outputs": [],
|
807 |
+
"source": [
|
808 |
+
"config = DistilBertConfig.from_pretrained(\"distilbert-base-uncased\")"
|
809 |
+
]
|
810 |
+
},
|
811 |
+
{
|
812 |
+
"cell_type": "code",
|
813 |
+
"execution_count": 64,
|
814 |
+
"id": "91444894",
|
815 |
+
"metadata": {},
|
816 |
+
"outputs": [],
|
817 |
+
"source": [
|
818 |
+
"# only take base model, we do not need the classification head\n",
|
819 |
+
"mod = model.distilbert"
|
820 |
+
]
|
821 |
+
},
|
822 |
+
{
|
823 |
+
"cell_type": "code",
|
824 |
+
"execution_count": 65,
|
825 |
+
"id": "74ca6c07",
|
826 |
+
"metadata": {},
|
827 |
+
"outputs": [
|
828 |
+
{
|
829 |
+
"data": {
|
830 |
+
"text/plain": [
|
831 |
+
"QuestionDistilBERT(\n",
|
832 |
+
" (distilbert): DistilBertModel(\n",
|
833 |
+
" (embeddings): Embeddings(\n",
|
834 |
+
" (word_embeddings): Embedding(30522, 768, padding_idx=0)\n",
|
835 |
+
" (position_embeddings): Embedding(512, 768)\n",
|
836 |
+
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
|
837 |
+
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
838 |
+
" )\n",
|
839 |
+
" (transformer): Transformer(\n",
|
840 |
+
" (layer): ModuleList(\n",
|
841 |
+
" (0): TransformerBlock(\n",
|
842 |
+
" (attention): MultiHeadSelfAttention(\n",
|
843 |
+
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
844 |
+
" (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
845 |
+
" (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
846 |
+
" (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
847 |
+
" (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
848 |
+
" )\n",
|
849 |
+
" (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
|
850 |
+
" (ffn): FFN(\n",
|
851 |
+
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
852 |
+
" (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
|
853 |
+
" (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
|
854 |
+
" (activation): GELUActivation()\n",
|
855 |
+
" )\n",
|
856 |
+
" (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
|
857 |
+
" )\n",
|
858 |
+
" (1): TransformerBlock(\n",
|
859 |
+
" (attention): MultiHeadSelfAttention(\n",
|
860 |
+
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
861 |
+
" (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
862 |
+
" (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
863 |
+
" (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
864 |
+
" (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
865 |
+
" )\n",
|
866 |
+
" (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
|
867 |
+
" (ffn): FFN(\n",
|
868 |
+
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
869 |
+
" (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
|
870 |
+
" (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
|
871 |
+
" (activation): GELUActivation()\n",
|
872 |
+
" )\n",
|
873 |
+
" (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
|
874 |
+
" )\n",
|
875 |
+
" (2): TransformerBlock(\n",
|
876 |
+
" (attention): MultiHeadSelfAttention(\n",
|
877 |
+
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
878 |
+
" (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
879 |
+
" (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
880 |
+
" (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
881 |
+
" (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
882 |
+
" )\n",
|
883 |
+
" (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
|
884 |
+
" (ffn): FFN(\n",
|
885 |
+
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
886 |
+
" (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
|
887 |
+
" (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
|
888 |
+
" (activation): GELUActivation()\n",
|
889 |
+
" )\n",
|
890 |
+
" (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
|
891 |
+
" )\n",
|
892 |
+
" (3): TransformerBlock(\n",
|
893 |
+
" (attention): MultiHeadSelfAttention(\n",
|
894 |
+
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
895 |
+
" (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
896 |
+
" (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
897 |
+
" (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
898 |
+
" (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
899 |
+
" )\n",
|
900 |
+
" (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
|
901 |
+
" (ffn): FFN(\n",
|
902 |
+
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
903 |
+
" (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
|
904 |
+
" (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
|
905 |
+
" (activation): GELUActivation()\n",
|
906 |
+
" )\n",
|
907 |
+
" (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
|
908 |
+
" )\n",
|
909 |
+
" (4): TransformerBlock(\n",
|
910 |
+
" (attention): MultiHeadSelfAttention(\n",
|
911 |
+
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
912 |
+
" (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
913 |
+
" (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
914 |
+
" (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
915 |
+
" (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
916 |
+
" )\n",
|
917 |
+
" (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
|
918 |
+
" (ffn): FFN(\n",
|
919 |
+
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
920 |
+
" (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
|
921 |
+
" (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
|
922 |
+
" (activation): GELUActivation()\n",
|
923 |
+
" )\n",
|
924 |
+
" (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
|
925 |
+
" )\n",
|
926 |
+
" (5): TransformerBlock(\n",
|
927 |
+
" (attention): MultiHeadSelfAttention(\n",
|
928 |
+
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
929 |
+
" (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
930 |
+
" (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
931 |
+
" (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
932 |
+
" (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
933 |
+
" )\n",
|
934 |
+
" (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
|
935 |
+
" (ffn): FFN(\n",
|
936 |
+
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
937 |
+
" (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
|
938 |
+
" (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
|
939 |
+
" (activation): GELUActivation()\n",
|
940 |
+
" )\n",
|
941 |
+
" (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
|
942 |
+
" )\n",
|
943 |
+
" )\n",
|
944 |
+
" )\n",
|
945 |
+
" )\n",
|
946 |
+
" (relu): ReLU()\n",
|
947 |
+
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
948 |
+
" (te): TransformerEncoder(\n",
|
949 |
+
" (layers): ModuleList(\n",
|
950 |
+
" (0): TransformerEncoderLayer(\n",
|
951 |
+
" (self_attn): MultiheadAttention(\n",
|
952 |
+
" (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)\n",
|
953 |
+
" )\n",
|
954 |
+
" (linear1): Linear(in_features=768, out_features=2048, bias=True)\n",
|
955 |
+
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
956 |
+
" (linear2): Linear(in_features=2048, out_features=768, bias=True)\n",
|
957 |
+
" (norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
|
958 |
+
" (norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
|
959 |
+
" (dropout1): Dropout(p=0.1, inplace=False)\n",
|
960 |
+
" (dropout2): Dropout(p=0.1, inplace=False)\n",
|
961 |
+
" )\n",
|
962 |
+
" (1): TransformerEncoderLayer(\n",
|
963 |
+
" (self_attn): MultiheadAttention(\n",
|
964 |
+
" (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)\n",
|
965 |
+
" )\n",
|
966 |
+
" (linear1): Linear(in_features=768, out_features=2048, bias=True)\n",
|
967 |
+
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
968 |
+
" (linear2): Linear(in_features=2048, out_features=768, bias=True)\n",
|
969 |
+
" (norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
|
970 |
+
" (norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
|
971 |
+
" (dropout1): Dropout(p=0.1, inplace=False)\n",
|
972 |
+
" (dropout2): Dropout(p=0.1, inplace=False)\n",
|
973 |
+
" )\n",
|
974 |
+
" (2): TransformerEncoderLayer(\n",
|
975 |
+
" (self_attn): MultiheadAttention(\n",
|
976 |
+
" (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)\n",
|
977 |
+
" )\n",
|
978 |
+
" (linear1): Linear(in_features=768, out_features=2048, bias=True)\n",
|
979 |
+
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
980 |
+
" (linear2): Linear(in_features=2048, out_features=768, bias=True)\n",
|
981 |
+
" (norm1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
|
982 |
+
" (norm2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)\n",
|
983 |
+
" (dropout1): Dropout(p=0.1, inplace=False)\n",
|
984 |
+
" (dropout2): Dropout(p=0.1, inplace=False)\n",
|
985 |
+
" )\n",
|
986 |
+
" )\n",
|
987 |
+
" )\n",
|
988 |
+
" (classifier): Sequential(\n",
|
989 |
+
" (0): Dropout(p=0.1, inplace=False)\n",
|
990 |
+
" (1): ReLU()\n",
|
991 |
+
" (2): Linear(in_features=768, out_features=512, bias=True)\n",
|
992 |
+
" (3): Dropout(p=0.1, inplace=False)\n",
|
993 |
+
" (4): ReLU()\n",
|
994 |
+
" (5): Linear(in_features=512, out_features=256, bias=True)\n",
|
995 |
+
" (6): Dropout(p=0.1, inplace=False)\n",
|
996 |
+
" (7): ReLU()\n",
|
997 |
+
" (8): Linear(in_features=256, out_features=128, bias=True)\n",
|
998 |
+
" (9): Dropout(p=0.1, inplace=False)\n",
|
999 |
+
" (10): ReLU()\n",
|
1000 |
+
" (11): Linear(in_features=128, out_features=64, bias=True)\n",
|
1001 |
+
" (12): Dropout(p=0.1, inplace=False)\n",
|
1002 |
+
" (13): ReLU()\n",
|
1003 |
+
" (14): Linear(in_features=64, out_features=2, bias=True)\n",
|
1004 |
+
" )\n",
|
1005 |
+
")"
|
1006 |
+
]
|
1007 |
+
},
|
1008 |
+
"execution_count": 65,
|
1009 |
+
"metadata": {},
|
1010 |
+
"output_type": "execute_result"
|
1011 |
+
}
|
1012 |
+
],
|
1013 |
+
"source": [
|
1014 |
+
"device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')\n",
|
1015 |
+
"model = QuestionDistilBERT(mod)\n",
|
1016 |
+
"model.to(device)"
|
1017 |
+
]
|
1018 |
+
},
|
1019 |
+
{
|
1020 |
+
"cell_type": "code",
|
1021 |
+
"execution_count": 66,
|
1022 |
+
"id": "340857f9",
|
1023 |
+
"metadata": {},
|
1024 |
+
"outputs": [
|
1025 |
+
{
|
1026 |
+
"name": "stdout",
|
1027 |
+
"output_type": "stream",
|
1028 |
+
"text": [
|
1029 |
+
"+---------------------------------------+------------+\n",
|
1030 |
+
"| Modules | Parameters |\n",
|
1031 |
+
"+---------------------------------------+------------+\n",
|
1032 |
+
"| te.layers.0.self_attn.in_proj_weight | 1769472 |\n",
|
1033 |
+
"| te.layers.0.self_attn.in_proj_bias | 2304 |\n",
|
1034 |
+
"| te.layers.0.self_attn.out_proj.weight | 589824 |\n",
|
1035 |
+
"| te.layers.0.self_attn.out_proj.bias | 768 |\n",
|
1036 |
+
"| te.layers.0.linear1.weight | 1572864 |\n",
|
1037 |
+
"| te.layers.0.linear1.bias | 2048 |\n",
|
1038 |
+
"| te.layers.0.linear2.weight | 1572864 |\n",
|
1039 |
+
"| te.layers.0.linear2.bias | 768 |\n",
|
1040 |
+
"| te.layers.0.norm1.weight | 768 |\n",
|
1041 |
+
"| te.layers.0.norm1.bias | 768 |\n",
|
1042 |
+
"| te.layers.0.norm2.weight | 768 |\n",
|
1043 |
+
"| te.layers.0.norm2.bias | 768 |\n",
|
1044 |
+
"| te.layers.1.self_attn.in_proj_weight | 1769472 |\n",
|
1045 |
+
"| te.layers.1.self_attn.in_proj_bias | 2304 |\n",
|
1046 |
+
"| te.layers.1.self_attn.out_proj.weight | 589824 |\n",
|
1047 |
+
"| te.layers.1.self_attn.out_proj.bias | 768 |\n",
|
1048 |
+
"| te.layers.1.linear1.weight | 1572864 |\n",
|
1049 |
+
"| te.layers.1.linear1.bias | 2048 |\n",
|
1050 |
+
"| te.layers.1.linear2.weight | 1572864 |\n",
|
1051 |
+
"| te.layers.1.linear2.bias | 768 |\n",
|
1052 |
+
"| te.layers.1.norm1.weight | 768 |\n",
|
1053 |
+
"| te.layers.1.norm1.bias | 768 |\n",
|
1054 |
+
"| te.layers.1.norm2.weight | 768 |\n",
|
1055 |
+
"| te.layers.1.norm2.bias | 768 |\n",
|
1056 |
+
"| te.layers.2.self_attn.in_proj_weight | 1769472 |\n",
|
1057 |
+
"| te.layers.2.self_attn.in_proj_bias | 2304 |\n",
|
1058 |
+
"| te.layers.2.self_attn.out_proj.weight | 589824 |\n",
|
1059 |
+
"| te.layers.2.self_attn.out_proj.bias | 768 |\n",
|
1060 |
+
"| te.layers.2.linear1.weight | 1572864 |\n",
|
1061 |
+
"| te.layers.2.linear1.bias | 2048 |\n",
|
1062 |
+
"| te.layers.2.linear2.weight | 1572864 |\n",
|
1063 |
+
"| te.layers.2.linear2.bias | 768 |\n",
|
1064 |
+
"| te.layers.2.norm1.weight | 768 |\n",
|
1065 |
+
"| te.layers.2.norm1.bias | 768 |\n",
|
1066 |
+
"| te.layers.2.norm2.weight | 768 |\n",
|
1067 |
+
"| te.layers.2.norm2.bias | 768 |\n",
|
1068 |
+
"| classifier.2.weight | 393216 |\n",
|
1069 |
+
"| classifier.2.bias | 512 |\n",
|
1070 |
+
"| classifier.5.weight | 131072 |\n",
|
1071 |
+
"| classifier.5.bias | 256 |\n",
|
1072 |
+
"| classifier.8.weight | 32768 |\n",
|
1073 |
+
"| classifier.8.bias | 128 |\n",
|
1074 |
+
"| classifier.11.weight | 8192 |\n",
|
1075 |
+
"| classifier.11.bias | 64 |\n",
|
1076 |
+
"| classifier.14.weight | 128 |\n",
|
1077 |
+
"| classifier.14.bias | 2 |\n",
|
1078 |
+
"+---------------------------------------+------------+\n",
|
1079 |
+
"Total Trainable Params: 17108290\n"
|
1080 |
+
]
|
1081 |
+
},
|
1082 |
+
{
|
1083 |
+
"data": {
|
1084 |
+
"text/plain": [
|
1085 |
+
"17108290"
|
1086 |
+
]
|
1087 |
+
},
|
1088 |
+
"execution_count": 66,
|
1089 |
+
"metadata": {},
|
1090 |
+
"output_type": "execute_result"
|
1091 |
+
}
|
1092 |
+
],
|
1093 |
+
"source": [
|
1094 |
+
"count_parameters(model)"
|
1095 |
+
]
|
1096 |
+
},
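count_parameters prints the table of trainable parameters shown above; the helper itself lives in the project's utility code. A sketch that reproduces this output format, assuming the prettytable package, might look like this:

from prettytable import PrettyTable

def count_parameters(model):
    # tabulate only the parameters that will actually be updated during fine-tuning
    table = PrettyTable(["Modules", "Parameters"])
    total_params = 0
    for name, parameter in model.named_parameters():
        if not parameter.requires_grad:
            continue
        num = parameter.numel()
        table.add_row([name, num])
        total_params += num
    print(table)
    print(f"Total Trainable Params: {total_params}")
    return total_params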
|
1097 |
+
{
|
1098 |
+
"cell_type": "markdown",
|
1099 |
+
"id": "9babd013",
|
1100 |
+
"metadata": {},
|
1101 |
+
"source": [
|
1102 |
+
"### Testing the model\n",
|
1103 |
+
"This is the same procedure as in `distilbert.ipynb`. "
|
1104 |
+
]
|
1105 |
+
},
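test_model is the project's quick sanity check and prints "Passed" in the next cell. The function below is a hypothetical stand-in, not the project's implementation: it runs a single optimisation step on one batch and asserts that a finite loss and non-zero gradients come back.

import math
import torch

def smoke_test(model, optim, loader, device):
    # hypothetical stand-in for the notebook's test_model helper
    model.train()
    batch = next(iter(loader))
    optim.zero_grad()
    outputs = model(batch['input_ids'].to(device),
                    attention_mask=batch['attention_mask'].to(device),
                    start_positions=batch['start_positions'].to(device),
                    end_positions=batch['end_positions'].to(device))
    loss = outputs['loss']
    assert loss is not None and math.isfinite(loss.item()), "loss is not finite"
    loss.backward()
    assert any(p.grad is not None and p.grad.abs().sum() > 0
               for p in model.parameters() if p.requires_grad), "no gradients reached trainable parameters"
    optim.step()
    print("Passed")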
|
1106 |
+
{
|
1107 |
+
"cell_type": "code",
|
1108 |
+
"execution_count": 67,
|
1109 |
+
"id": "694c828b",
|
1110 |
+
"metadata": {},
|
1111 |
+
"outputs": [],
|
1112 |
+
"source": [
|
1113 |
+
"# get smaller dataset\n",
|
1114 |
+
"batch_size = 8\n",
|
1115 |
+
"test_ds = Dataset(squad_paths = squad_paths[:2], natural_question_paths=None, hotpotqa_paths=None, tokenizer=tokenizer)\n",
|
1116 |
+
"test_ds_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size)\n",
|
1117 |
+
"optim=torch.optim.Adam(model.parameters())"
|
1118 |
+
]
|
1119 |
+
},
|
1120 |
+
{
|
1121 |
+
"cell_type": "code",
|
1122 |
+
"execution_count": 68,
|
1123 |
+
"id": "a76587df",
|
1124 |
+
"metadata": {},
|
1125 |
+
"outputs": [
|
1126 |
+
{
|
1127 |
+
"name": "stdout",
|
1128 |
+
"output_type": "stream",
|
1129 |
+
"text": [
|
1130 |
+
"Passed\n"
|
1131 |
+
]
|
1132 |
+
}
|
1133 |
+
],
|
1134 |
+
"source": [
|
1135 |
+
"test_model(model, optim, test_ds_loader, device)"
|
1136 |
+
]
|
1137 |
+
},
|
1138 |
+
{
|
1139 |
+
"cell_type": "markdown",
|
1140 |
+
"id": "7c326e8e",
|
1141 |
+
"metadata": {},
|
1142 |
+
"source": [
|
1143 |
+
"### Training the model\n",
|
1144 |
+
"* Parameter Tuning:\n",
|
1145 |
+
" * Learning Rate: I experimented with several values, 1e-4 seemed to work best for me. 1e-3 was very unstable and 1e-5 was too small.\n",
|
1146 |
+
" * Gradient Clipping: I experimented with this, but the difference was only minimal\n",
|
1147 |
+
"\n",
|
1148 |
+
"Data:\n",
|
1149 |
+
"* I first used only the SQuAD dataset, but generalisation is a problem\n",
|
1150 |
+
" * The dataset is realtively small and we often have entries with the same context but different questions\n",
|
1151 |
+
" * I believe, the diversity is not big enough to train a fully functional model\n",
|
1152 |
+
"* Hence, I included the Natural Questions dataset too\n",
|
1153 |
+
" * It is however a lot more messy - I elaborated a bit more on this in `load_data.ipynb`\n",
|
1154 |
+
"* Also the hotpotqa data was used\n",
|
1155 |
+
"\n",
|
1156 |
+
"Tested with: \n",
|
1157 |
+
"* 3 Linear Layers\n",
|
1158 |
+
" * Training Error high - needed more layers\n",
|
1159 |
+
" * Already expected - this was mostly a Proof of Concept\n",
|
1160 |
+
"* 1 TransformerEncoder with 4 attention heads + 1 Linear Layer:\n",
|
1161 |
+
" * Training Error was high, still too simple\n",
|
1162 |
+
"* 1 TransformerEncoder with 8 heads + 1 Linear Layer:\n",
|
1163 |
+
" * Training Error gets lower, however stagnates at some point\n",
|
1164 |
+
" * Probably still too simple, it doesn't generalise either\n",
|
1165 |
+
"* 2 TransformerEncoder with 8 and 4 heads + 1 Linear Layer:\n",
|
1166 |
+
" * Loss gets down but doesn't go further after some time\n"
|
1167 |
+
]
|
1168 |
+
},
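As referenced in the gradient-clipping bullet above, the only change to the training step is one extra call between loss.backward() and optim.step(). A self-contained sketch of that per-batch step, using the same batch field names as the loaders in this notebook:

import torch

def training_step_with_clipping(model, optim, batch, device, max_norm=1.0):
    # one batch update, identical to the training loop below except for the clipping call
    optim.zero_grad()
    outputs = model(batch['input_ids'].to(device),
                    attention_mask=batch['attention_mask'].to(device),
                    start_positions=batch['start_positions'].to(device),
                    end_positions=batch['end_positions'].to(device))
    loss = outputs['loss']
    loss.backward()
    # clip the global gradient norm before the optimiser step (max_norm=1 was the value tried here)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_norm)
    optim.step()
    return loss.item()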
|
1169 |
+
{
|
1170 |
+
"cell_type": "code",
|
1171 |
+
"execution_count": null,
|
1172 |
+
"id": "2e9f4bd3",
|
1173 |
+
"metadata": {},
|
1174 |
+
"outputs": [],
|
1175 |
+
"source": [
|
1176 |
+
"dataset = Dataset(squad_paths = squad_paths, natural_question_paths=nat_paths, hotpotqa_paths=hotpotqa_paths, tokenizer=tokenizer)\n",
|
1177 |
+
"loader = torch.utils.data.DataLoader(dataset, batch_size=8)\n",
|
1178 |
+
"\n",
|
1179 |
+
"test_dataset = Dataset(squad_paths = [str(x) for x in Path('data/test_squad/').glob('**/*.txt')], \n",
|
1180 |
+
" natural_question_paths=None, \n",
|
1181 |
+
" hotpotqa_paths = None, tokenizer=tokenizer)\n",
|
1182 |
+
"test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=4)"
|
1183 |
+
]
|
1184 |
+
},
|
1185 |
+
{
|
1186 |
+
"cell_type": "code",
|
1187 |
+
"execution_count": 26,
|
1188 |
+
"id": "03a6de37",
|
1189 |
+
"metadata": {},
|
1190 |
+
"outputs": [],
|
1191 |
+
"source": [
|
1192 |
+
"model = QuestionDistilBERT(mod)"
|
1193 |
+
]
|
1194 |
+
},
|
1195 |
+
{
|
1196 |
+
"cell_type": "code",
|
1197 |
+
"execution_count": 41,
|
1198 |
+
"id": "ed854b73",
|
1199 |
+
"metadata": {},
|
1200 |
+
"outputs": [],
|
1201 |
+
"source": [
|
1202 |
+
"from torch.optim import AdamW, RMSprop\n",
|
1203 |
+
"\n",
|
1204 |
+
"model.train()\n",
|
1205 |
+
"optim = RMSprop(model.parameters(), lr=1e-4)"
|
1206 |
+
]
|
1207 |
+
},
|
1208 |
+
{
|
1209 |
+
"cell_type": "code",
|
1210 |
+
"execution_count": 42,
|
1211 |
+
"id": "79fdfcc9",
|
1212 |
+
"metadata": {},
|
1213 |
+
"outputs": [],
|
1214 |
+
"source": [
|
1215 |
+
"from torch.utils.tensorboard import SummaryWriter\n",
|
1216 |
+
"writer = SummaryWriter()"
|
1217 |
+
]
|
1218 |
+
},
|
1219 |
+
{
|
1220 |
+
"cell_type": "code",
|
1221 |
+
"execution_count": null,
|
1222 |
+
"id": "f7bddb43",
|
1223 |
+
"metadata": {},
|
1224 |
+
"outputs": [
|
1225 |
+
{
|
1226 |
+
"data": {
|
1227 |
+
"application/vnd.jupyter.widget-view+json": {
|
1228 |
+
"model_id": "5e9e74167c4b4b22b3218f4ca3c5abf0",
|
1229 |
+
"version_major": 2,
|
1230 |
+
"version_minor": 0
|
1231 |
+
},
|
1232 |
+
"text/plain": [
|
1233 |
+
" 0%| | 0/21750 [00:00<?, ?it/s]"
|
1234 |
+
]
|
1235 |
+
},
|
1236 |
+
"metadata": {},
|
1237 |
+
"output_type": "display_data"
|
1238 |
+
},
|
1239 |
+
{
|
1240 |
+
"name": "stdout",
|
1241 |
+
"output_type": "stream",
|
1242 |
+
"text": [
|
1243 |
+
"Mean Training Error 3.8791405910185013\n"
|
1244 |
+
]
|
1245 |
+
},
|
1246 |
+
{
|
1247 |
+
"data": {
|
1248 |
+
"application/vnd.jupyter.widget-view+json": {
|
1249 |
+
"model_id": "f3ce562fc61d4bfc83a4860eb06bc20c",
|
1250 |
+
"version_major": 2,
|
1251 |
+
"version_minor": 0
|
1252 |
+
},
|
1253 |
+
"text/plain": [
|
1254 |
+
" 0%| | 0/1250 [00:00<?, ?it/s]"
|
1255 |
+
]
|
1256 |
+
},
|
1257 |
+
"metadata": {},
|
1258 |
+
"output_type": "display_data"
|
1259 |
+
},
|
1260 |
+
{
|
1261 |
+
"name": "stdout",
|
1262 |
+
"output_type": "stream",
|
1263 |
+
"text": [
|
1264 |
+
"Mean Test Error 3.7705092002868654\n"
|
1265 |
+
]
|
1266 |
+
},
|
1267 |
+
{
|
1268 |
+
"data": {
|
1269 |
+
"application/vnd.jupyter.widget-view+json": {
|
1270 |
+
"model_id": "2e84e21cedd446a0a5f5a40501711d1c",
|
1271 |
+
"version_major": 2,
|
1272 |
+
"version_minor": 0
|
1273 |
+
},
|
1274 |
+
"text/plain": [
|
1275 |
+
" 0%| | 0/21750 [00:00<?, ?it/s]"
|
1276 |
+
]
|
1277 |
+
},
|
1278 |
+
"metadata": {},
|
1279 |
+
"output_type": "display_data"
|
1280 |
+
},
|
1281 |
+
{
|
1282 |
+
"name": "stdout",
|
1283 |
+
"output_type": "stream",
|
1284 |
+
"text": [
|
1285 |
+
"Mean Training Error 3.7389922174091996\n"
|
1286 |
+
]
|
1287 |
+
},
|
1288 |
+
{
|
1289 |
+
"data": {
|
1290 |
+
"application/vnd.jupyter.widget-view+json": {
|
1291 |
+
"model_id": "07135c48be0146498cd37d767c1ee6ab",
|
1292 |
+
"version_major": 2,
|
1293 |
+
"version_minor": 0
|
1294 |
+
},
|
1295 |
+
"text/plain": [
|
1296 |
+
" 0%| | 0/1250 [00:00<?, ?it/s]"
|
1297 |
+
]
|
1298 |
+
},
|
1299 |
+
"metadata": {},
|
1300 |
+
"output_type": "display_data"
|
1301 |
+
},
|
1302 |
+
{
|
1303 |
+
"name": "stdout",
|
1304 |
+
"output_type": "stream",
|
1305 |
+
"text": [
|
1306 |
+
"Mean Test Error 3.7443671816825868\n"
|
1307 |
+
]
|
1308 |
+
},
|
1309 |
+
{
|
1310 |
+
"data": {
|
1311 |
+
"application/vnd.jupyter.widget-view+json": {
|
1312 |
+
"model_id": "e9a51fbabc7043c2819a68e247e4a3ec",
|
1313 |
+
"version_major": 2,
|
1314 |
+
"version_minor": 0
|
1315 |
+
},
|
1316 |
+
"text/plain": [
|
1317 |
+
" 0%| | 0/21750 [00:00<?, ?it/s]"
|
1318 |
+
]
|
1319 |
+
},
|
1320 |
+
"metadata": {},
|
1321 |
+
"output_type": "display_data"
|
1322 |
+
},
|
1323 |
+
{
|
1324 |
+
"name": "stdout",
|
1325 |
+
"output_type": "stream",
|
1326 |
+
"text": [
|
1327 |
+
"Mean Training Error 3.7031057048117977\n"
|
1328 |
+
]
|
1329 |
+
},
|
1330 |
+
{
|
1331 |
+
"data": {
|
1332 |
+
"application/vnd.jupyter.widget-view+json": {
|
1333 |
+
"model_id": "bfdbcc9fe32542a19c47bc1d7704400e",
|
1334 |
+
"version_major": 2,
|
1335 |
+
"version_minor": 0
|
1336 |
+
},
|
1337 |
+
"text/plain": [
|
1338 |
+
" 0%| | 0/1250 [00:00<?, ?it/s]"
|
1339 |
+
]
|
1340 |
+
},
|
1341 |
+
"metadata": {},
|
1342 |
+
"output_type": "display_data"
|
1343 |
+
},
|
1344 |
+
{
|
1345 |
+
"name": "stdout",
|
1346 |
+
"output_type": "stream",
|
1347 |
+
"text": [
|
1348 |
+
"Mean Test Error 3.743248237323761\n"
|
1349 |
+
]
|
1350 |
+
},
|
1351 |
+
{
|
1352 |
+
"data": {
|
1353 |
+
"application/vnd.jupyter.widget-view+json": {
|
1354 |
+
"model_id": "81fd1278b22643dc9fb3ac306533a240",
|
1355 |
+
"version_major": 2,
|
1356 |
+
"version_minor": 0
|
1357 |
+
},
|
1358 |
+
"text/plain": [
|
1359 |
+
" 0%| | 0/21750 [00:00<?, ?it/s]"
|
1360 |
+
]
|
1361 |
+
},
|
1362 |
+
"metadata": {},
|
1363 |
+
"output_type": "display_data"
|
1364 |
+
},
|
1365 |
+
{
|
1366 |
+
"name": "stdout",
|
1367 |
+
"output_type": "stream",
|
1368 |
+
"text": [
|
1369 |
+
"Mean Training Error 3.6711661003430685\n"
|
1370 |
+
]
|
1371 |
+
},
|
1372 |
+
{
|
1373 |
+
"data": {
|
1374 |
+
"application/vnd.jupyter.widget-view+json": {
|
1375 |
+
"model_id": "8b38d6cd44e048ec8bcd6b5cb86cce16",
|
1376 |
+
"version_major": 2,
|
1377 |
+
"version_minor": 0
|
1378 |
+
},
|
1379 |
+
"text/plain": [
|
1380 |
+
" 0%| | 0/1250 [00:00<?, ?it/s]"
|
1381 |
+
]
|
1382 |
+
},
|
1383 |
+
"metadata": {},
|
1384 |
+
"output_type": "display_data"
|
1385 |
+
},
|
1386 |
+
{
|
1387 |
+
"name": "stdout",
|
1388 |
+
"output_type": "stream",
|
1389 |
+
"text": [
|
1390 |
+
"Mean Test Error 3.740310479736328\n"
|
1391 |
+
]
|
1392 |
+
},
|
1393 |
+
{
|
1394 |
+
"data": {
|
1395 |
+
"application/vnd.jupyter.widget-view+json": {
|
1396 |
+
"model_id": "825248aa3f934f4aade9d973e6f3b43e",
|
1397 |
+
"version_major": 2,
|
1398 |
+
"version_minor": 0
|
1399 |
+
},
|
1400 |
+
"text/plain": [
|
1401 |
+
" 0%| | 0/21750 [00:00<?, ?it/s]"
|
1402 |
+
]
|
1403 |
+
},
|
1404 |
+
"metadata": {},
|
1405 |
+
"output_type": "display_data"
|
1406 |
+
},
|
1407 |
+
{
|
1408 |
+
"name": "stdout",
|
1409 |
+
"output_type": "stream",
|
1410 |
+
"text": [
|
1411 |
+
"Mean Training Error 3.6591619139813827\n"
|
1412 |
+
]
|
1413 |
+
},
|
1414 |
+
{
|
1415 |
+
"data": {
|
1416 |
+
"application/vnd.jupyter.widget-view+json": {
|
1417 |
+
"model_id": "edceb7af0ec6450997820967638c12db",
|
1418 |
+
"version_major": 2,
|
1419 |
+
"version_minor": 0
|
1420 |
+
},
|
1421 |
+
"text/plain": [
|
1422 |
+
" 0%| | 0/1250 [00:00<?, ?it/s]"
|
1423 |
+
]
|
1424 |
+
},
|
1425 |
+
"metadata": {},
|
1426 |
+
"output_type": "display_data"
|
1427 |
+
},
|
1428 |
+
{
|
1429 |
+
"name": "stdout",
|
1430 |
+
"output_type": "stream",
|
1431 |
+
"text": [
|
1432 |
+
"Mean Test Error 3.8138498876571654\n"
|
1433 |
+
]
|
1434 |
+
},
|
1435 |
+
{
|
1436 |
+
"data": {
|
1437 |
+
"application/vnd.jupyter.widget-view+json": {
|
1438 |
+
"model_id": "27e903eb0d0f4f949c234e4faf4277a1",
|
1439 |
+
"version_major": 2,
|
1440 |
+
"version_minor": 0
|
1441 |
+
},
|
1442 |
+
"text/plain": [
|
1443 |
+
" 0%| | 0/21750 [00:00<?, ?it/s]"
|
1444 |
+
]
|
1445 |
+
},
|
1446 |
+
"metadata": {},
|
1447 |
+
"output_type": "display_data"
|
1448 |
+
}
|
1449 |
+
],
|
1450 |
+
"source": [
|
1451 |
+
"epochs = 20\n",
|
1452 |
+
"\n",
|
1453 |
+
"for epoch in range(epochs):\n",
|
1454 |
+
" loop = tqdm(loader, leave=True)\n",
|
1455 |
+
" model.train()\n",
|
1456 |
+
" mean_training_error = []\n",
|
1457 |
+
" for batch in loop:\n",
|
1458 |
+
" optim.zero_grad()\n",
|
1459 |
+
" \n",
|
1460 |
+
" input_ids = batch['input_ids'].to(device)\n",
|
1461 |
+
" attention_mask = batch['attention_mask'].to(device)\n",
|
1462 |
+
" start = batch['start_positions'].to(device)\n",
|
1463 |
+
" end = batch['end_positions'].to(device)\n",
|
1464 |
+
" \n",
|
1465 |
+
" outputs = model(input_ids, attention_mask=attention_mask, start_positions=start, end_positions=end)\n",
|
1466 |
+
" \n",
|
1467 |
+
" loss = outputs['loss']\n",
|
1468 |
+
" loss.backward()\n",
|
1469 |
+
" \n",
|
1470 |
+
" optim.step()\n",
|
1471 |
+
" mean_training_error.append(loss.item())\n",
|
1472 |
+
" loop.set_description(f'Epoch {epoch}')\n",
|
1473 |
+
" loop.set_postfix(loss=loss.item())\n",
|
1474 |
+
" print(\"Mean Training Error\", np.mean(mean_training_error))\n",
|
1475 |
+
" writer.add_scalar(\"Loss/train\", np.mean(mean_training_error), epoch)\n",
|
1476 |
+
" \n",
|
1477 |
+
" loop = tqdm(test_loader, leave=True)\n",
|
1478 |
+
" model.eval()\n",
|
1479 |
+
" mean_test_error = []\n",
|
1480 |
+
" for batch in loop:\n",
|
1481 |
+
" \n",
|
1482 |
+
" input_ids = batch['input_ids'].to(device)\n",
|
1483 |
+
" attention_mask = batch['attention_mask'].to(device)\n",
|
1484 |
+
" start = batch['start_positions'].to(device)\n",
|
1485 |
+
" end = batch['end_positions'].to(device)\n",
|
1486 |
+
" \n",
|
1487 |
+
" outputs = model(input_ids, attention_mask=attention_mask, start_positions=start, end_positions=end)\n",
|
1488 |
+
" # print(torch.argmax(outputs['start_logits'],axis=1), torch.argmax(outputs['end_logits'], axis=1), start, end)\n",
|
1489 |
+
" loss = outputs['loss']\n",
|
1490 |
+
" \n",
|
1491 |
+
" mean_test_error.append(loss.item())\n",
|
1492 |
+
" loop.set_description(f'Epoch {epoch} Testset')\n",
|
1493 |
+
" loop.set_postfix(loss=loss.item())\n",
|
1494 |
+
" print(\"Mean Test Error\", np.mean(mean_test_error))\n",
|
1495 |
+
" writer.add_scalar(\"Loss/test\", np.mean(mean_test_error), epoch)"
|
1496 |
+
]
|
1497 |
+
},
|
1498 |
+
{
|
1499 |
+
"cell_type": "code",
|
1500 |
+
"execution_count": 238,
|
1501 |
+
"id": "a9d6af2e",
|
1502 |
+
"metadata": {},
|
1503 |
+
"outputs": [],
|
1504 |
+
"source": [
|
1505 |
+
"writer.close()"
|
1506 |
+
]
|
1507 |
+
},
|
1508 |
+
{
|
1509 |
+
"cell_type": "code",
|
1510 |
+
"execution_count": 33,
|
1511 |
+
"id": "ba43447e",
|
1512 |
+
"metadata": {},
|
1513 |
+
"outputs": [],
|
1514 |
+
"source": [
|
1515 |
+
"torch.save(model.state_dict(), \"distilbert_qa.model\")"
|
1516 |
+
]
|
1517 |
+
},
|
1518 |
+
{
|
1519 |
+
"cell_type": "code",
|
1520 |
+
"execution_count": 34,
|
1521 |
+
"id": "ffc49aca",
|
1522 |
+
"metadata": {},
|
1523 |
+
"outputs": [
|
1524 |
+
{
|
1525 |
+
"data": {
|
1526 |
+
"text/plain": [
|
1527 |
+
"<All keys matched successfully>"
|
1528 |
+
]
|
1529 |
+
},
|
1530 |
+
"execution_count": 34,
|
1531 |
+
"metadata": {},
|
1532 |
+
"output_type": "execute_result"
|
1533 |
+
}
|
1534 |
+
],
|
1535 |
+
"source": [
|
1536 |
+
"model = QuestionDistilBERT(mod)\n",
|
1537 |
+
"model.load_state_dict(torch.load(\"distilbert_qa.model\"))"
|
1538 |
+
]
|
1539 |
+
},
|
1540 |
+
{
|
1541 |
+
"cell_type": "code",
|
1542 |
+
"execution_count": 35,
|
1543 |
+
"id": "730a86c1",
|
1544 |
+
"metadata": {},
|
1545 |
+
"outputs": [
|
1546 |
+
{
|
1547 |
+
"name": "stderr",
|
1548 |
+
"output_type": "stream",
|
1549 |
+
"text": [
|
1550 |
+
"100%|鈻堚枅鈻堚枅鈻堚枅鈻堚枅鈻堚枅| 2500/2500 [02:57<00:00, 14.09it/s]"
|
1551 |
+
]
|
1552 |
+
},
|
1553 |
+
{
|
1554 |
+
"name": "stdout",
|
1555 |
+
"output_type": "stream",
|
1556 |
+
"text": [
|
1557 |
+
"Mean EM: 0.0479\n",
|
1558 |
+
"Mean F-1: 0.08989175857485086\n"
|
1559 |
+
]
|
1560 |
+
},
|
1561 |
+
{
|
1562 |
+
"name": "stderr",
|
1563 |
+
"output_type": "stream",
|
1564 |
+
"text": [
|
1565 |
+
"\n"
|
1566 |
+
]
|
1567 |
+
}
|
1568 |
+
],
|
1569 |
+
"source": [
|
1570 |
+
"eval_test_set(model, tokenizer, test_loader, device)"
|
1571 |
+
]
|
1572 |
+
},
|
1573 |
+
{
|
1574 |
+
"cell_type": "markdown",
|
1575 |
+
"id": "bd1c7076",
|
1576 |
+
"metadata": {},
|
1577 |
+
"source": [
|
1578 |
+
"## Reuse Layer\n",
|
1579 |
+
"This was inspired by how well the original model with just one classification head worked. I felt like the main problem with the previous model was the lack of structure which was already in the layers, combined with the massive amount of resources needed for a Transformer.\n",
|
1580 |
+
"\n",
|
1581 |
+
"Hence, I tried cloning the last (and then last two) layers of the DistilBERT model, putting a classifier on top and using this as the head. The base DistilBERT model is completely frozen. This worked extremely well, while we only fine-tune about 21% of the parameters (14 Mio as opposed to 66 Mio!) we did before. Below you can see the results.\n",
|
1582 |
+
"\n",
|
1583 |
+
"### Last DistilBERT layer\n",
|
1584 |
+
"\n",
|
1585 |
+
"Dropout 0.1 and RMSprop 1e-4:\n",
|
1586 |
+
"* Mean EM: 0.3888\n",
|
1587 |
+
"* Mean F-1: 0.5122932744694068\n",
|
1588 |
+
"\n",
|
1589 |
+
"Dropout 0.25: very early stagnating\n",
|
1590 |
+
"* Mean EM: 0.3552\n",
|
1591 |
+
"* Mean F-1: 0.4711235721312687\n",
|
1592 |
+
"\n",
|
1593 |
+
"Dropout 0.15: seems to work well - training and test error stagnate around 1.7 and 1.8 but good generalisation (need to add more layers)\n",
|
1594 |
+
"* Mean EM: 0.4119\n",
|
1595 |
+
"* Mean F-1: 0.5296387232893214\n",
|
1596 |
+
"\n",
|
1597 |
+
"### Last DitilBERT layer + more Dense layers\n",
|
1598 |
+
"Dropout 0.15 + 4 dense layers((786-512)-(512-256)-(256-128)-(128-2)) & ReLU: doesn't work too well - stagnates at around 2.4\n",
|
1599 |
+
"\n",
|
1600 |
+
"### Last two DistilBERT layers\n",
|
1601 |
+
"Dropout 0.1 but last 2 DistilBERT layers: works very well, but early overfitting - maybe use more data\n",
|
1602 |
+
"* Mean EM: 0.458\n",
|
1603 |
+
"* Mean F-1: 0.6003368353673634\n",
|
1604 |
+
"\n",
|
1605 |
+
"Dropout 0.1 - last 2 distilbert layers: all data\n",
|
1606 |
+
"* Mean EM: 0.484\n",
|
1607 |
+
"* Mean F-1: 0.6344960035215299\n",
|
1608 |
+
"\n",
|
1609 |
+
"Dropout 0.15 - **BEST**\n",
|
1610 |
+
"* Mean EM: 0.5178\n",
|
1611 |
+
"* Mean F-1: 0.6671140689626448\n",
|
1612 |
+
"\n",
|
1613 |
+
"Dropout 0.2 - doesn't work too well\n",
|
1614 |
+
"* Mean EM: 0.4353\n",
|
1615 |
+
"* Mean F-1: 0.5776847879304647\n"
|
1616 |
+
]
|
1617 |
+
},
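A rough sketch of the reuse idea described above: deep-copy the last one or two TransformerBlocks of the pretrained DistilBERT, freeze the base, and train only the copies plus a small start/end classifier. The class below is illustrative and simplified; the actual ReuseQuestionDistilBERT lives in qa_model.py and may differ in its details.

import copy
import torch
import torch.nn as nn

class ReusedLayerQAHead(nn.Module):
    # illustrative version of the 'reuse layer' idea, not the exact qa_model.py class
    def __init__(self, distilbert_base, n_reused_layers=2, dropout=0.15):
        super().__init__()
        self.distilbert = distilbert_base
        for p in self.distilbert.parameters():
            p.requires_grad = False                      # base stays frozen
        # clone the last n transformer blocks; only the copies are trained
        self.te = nn.ModuleList(
            copy.deepcopy(block) for block in self.distilbert.transformer.layer[-n_reused_layers:]
        )
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(768, 2)              # start/end logits per token
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, input_ids, attention_mask, start_positions=None, end_positions=None):
        hidden = self.distilbert(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        for block in self.te:
            # a HuggingFace TransformerBlock returns a tuple whose last element is the hidden states
            hidden = block(hidden, attn_mask=attention_mask, head_mask=None)[-1]
        logits = self.classifier(self.dropout(hidden))
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits, end_logits = start_logits.squeeze(-1), end_logits.squeeze(-1)
        loss = None
        if start_positions is not None and end_positions is not None:
            loss = (self.loss_fn(start_logits, start_positions)
                    + self.loss_fn(end_logits, end_positions)) / 2
        return {'loss': loss, 'start_logits': start_logits, 'end_logits': end_logits}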
|
1618 |
+
{
|
1619 |
+
"cell_type": "code",
|
1620 |
+
"execution_count": 69,
|
1621 |
+
"id": "654e09e8",
|
1622 |
+
"metadata": {},
|
1623 |
+
"outputs": [],
|
1624 |
+
"source": [
|
1625 |
+
"dataset = Dataset(squad_paths = squad_paths, natural_question_paths=None, hotpotqa_paths=hotpotqa_paths, tokenizer=tokenizer)\n",
|
1626 |
+
"loader = torch.utils.data.DataLoader(dataset, batch_size=8)\n",
|
1627 |
+
"\n",
|
1628 |
+
"test_dataset = Dataset(squad_paths = [str(x) for x in Path('data/test_squad/').glob('**/*.txt')], \n",
|
1629 |
+
" natural_question_paths=None, \n",
|
1630 |
+
" hotpotqa_paths = None, tokenizer=tokenizer)\n",
|
1631 |
+
"test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=4)"
|
1632 |
+
]
|
1633 |
+
},
|
1634 |
+
{
|
1635 |
+
"cell_type": "code",
|
1636 |
+
"execution_count": 70,
|
1637 |
+
"id": "707c0cb5",
|
1638 |
+
"metadata": {},
|
1639 |
+
"outputs": [
|
1640 |
+
{
|
1641 |
+
"data": {
|
1642 |
+
"text/plain": [
|
1643 |
+
"ReuseQuestionDistilBERT(\n",
|
1644 |
+
" (te): ModuleList(\n",
|
1645 |
+
" (0): TransformerBlock(\n",
|
1646 |
+
" (attention): MultiHeadSelfAttention(\n",
|
1647 |
+
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
1648 |
+
" (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
1649 |
+
" (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
1650 |
+
" (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
1651 |
+
" (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
1652 |
+
" )\n",
|
1653 |
+
" (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
|
1654 |
+
" (ffn): FFN(\n",
|
1655 |
+
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
1656 |
+
" (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
|
1657 |
+
" (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
|
1658 |
+
" (activation): GELUActivation()\n",
|
1659 |
+
" )\n",
|
1660 |
+
" (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
|
1661 |
+
" )\n",
|
1662 |
+
" (1): TransformerBlock(\n",
|
1663 |
+
" (attention): MultiHeadSelfAttention(\n",
|
1664 |
+
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
1665 |
+
" (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
1666 |
+
" (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
1667 |
+
" (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
1668 |
+
" (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
1669 |
+
" )\n",
|
1670 |
+
" (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
|
1671 |
+
" (ffn): FFN(\n",
|
1672 |
+
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
1673 |
+
" (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
|
1674 |
+
" (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
|
1675 |
+
" (activation): GELUActivation()\n",
|
1676 |
+
" )\n",
|
1677 |
+
" (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
|
1678 |
+
" )\n",
|
1679 |
+
" )\n",
|
1680 |
+
" (distilbert): DistilBertModel(\n",
|
1681 |
+
" (embeddings): Embeddings(\n",
|
1682 |
+
" (word_embeddings): Embedding(30522, 768, padding_idx=0)\n",
|
1683 |
+
" (position_embeddings): Embedding(512, 768)\n",
|
1684 |
+
" (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
|
1685 |
+
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
1686 |
+
" )\n",
|
1687 |
+
" (transformer): Transformer(\n",
|
1688 |
+
" (layer): ModuleList(\n",
|
1689 |
+
" (0): TransformerBlock(\n",
|
1690 |
+
" (attention): MultiHeadSelfAttention(\n",
|
1691 |
+
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
1692 |
+
" (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
1693 |
+
" (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
1694 |
+
" (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
1695 |
+
" (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
1696 |
+
" )\n",
|
1697 |
+
" (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
|
1698 |
+
" (ffn): FFN(\n",
|
1699 |
+
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
1700 |
+
" (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
|
1701 |
+
" (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
|
1702 |
+
" (activation): GELUActivation()\n",
|
1703 |
+
" )\n",
|
1704 |
+
" (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
|
1705 |
+
" )\n",
|
1706 |
+
" (1): TransformerBlock(\n",
|
1707 |
+
" (attention): MultiHeadSelfAttention(\n",
|
1708 |
+
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
1709 |
+
" (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
1710 |
+
" (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
1711 |
+
" (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
1712 |
+
" (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
1713 |
+
" )\n",
|
1714 |
+
" (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
|
1715 |
+
" (ffn): FFN(\n",
|
1716 |
+
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
1717 |
+
" (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
|
1718 |
+
" (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
|
1719 |
+
" (activation): GELUActivation()\n",
|
1720 |
+
" )\n",
|
1721 |
+
" (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
|
1722 |
+
" )\n",
|
1723 |
+
" (2): TransformerBlock(\n",
|
1724 |
+
" (attention): MultiHeadSelfAttention(\n",
|
1725 |
+
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
1726 |
+
" (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
1727 |
+
" (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
1728 |
+
" (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
1729 |
+
" (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
1730 |
+
" )\n",
|
1731 |
+
" (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
|
1732 |
+
" (ffn): FFN(\n",
|
1733 |
+
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
1734 |
+
" (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
|
1735 |
+
" (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
|
1736 |
+
" (activation): GELUActivation()\n",
|
1737 |
+
" )\n",
|
1738 |
+
" (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
|
1739 |
+
" )\n",
|
1740 |
+
" (3): TransformerBlock(\n",
|
1741 |
+
" (attention): MultiHeadSelfAttention(\n",
|
1742 |
+
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
1743 |
+
" (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
1744 |
+
" (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
1745 |
+
" (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
1746 |
+
" (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
1747 |
+
" )\n",
|
1748 |
+
" (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
|
1749 |
+
" (ffn): FFN(\n",
|
1750 |
+
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
1751 |
+
" (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
|
1752 |
+
" (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
|
1753 |
+
" (activation): GELUActivation()\n",
|
1754 |
+
" )\n",
|
1755 |
+
" (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
|
1756 |
+
" )\n",
|
1757 |
+
" (4): TransformerBlock(\n",
|
1758 |
+
" (attention): MultiHeadSelfAttention(\n",
|
1759 |
+
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
1760 |
+
" (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
1761 |
+
" (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
1762 |
+
" (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
1763 |
+
" (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
1764 |
+
" )\n",
|
1765 |
+
" (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
|
1766 |
+
" (ffn): FFN(\n",
|
1767 |
+
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
1768 |
+
" (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
|
1769 |
+
" (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
|
1770 |
+
" (activation): GELUActivation()\n",
|
1771 |
+
" )\n",
|
1772 |
+
" (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
|
1773 |
+
" )\n",
|
1774 |
+
" (5): TransformerBlock(\n",
|
1775 |
+
" (attention): MultiHeadSelfAttention(\n",
|
1776 |
+
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
1777 |
+
" (q_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
1778 |
+
" (k_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
1779 |
+
" (v_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
1780 |
+
" (out_lin): Linear(in_features=768, out_features=768, bias=True)\n",
|
1781 |
+
" )\n",
|
1782 |
+
" (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
|
1783 |
+
" (ffn): FFN(\n",
|
1784 |
+
" (dropout): Dropout(p=0.1, inplace=False)\n",
|
1785 |
+
" (lin1): Linear(in_features=768, out_features=3072, bias=True)\n",
|
1786 |
+
" (lin2): Linear(in_features=3072, out_features=768, bias=True)\n",
|
1787 |
+
" (activation): GELUActivation()\n",
|
1788 |
+
" )\n",
|
1789 |
+
" (output_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
|
1790 |
+
" )\n",
|
1791 |
+
" )\n",
|
1792 |
+
" )\n",
|
1793 |
+
" )\n",
|
1794 |
+
" (relu): ReLU()\n",
|
1795 |
+
" (dropout): Dropout(p=0.15, inplace=False)\n",
|
1796 |
+
" (classifier): Linear(in_features=768, out_features=2, bias=True)\n",
|
1797 |
+
")"
|
1798 |
+
]
|
1799 |
+
},
|
1800 |
+
"execution_count": 70,
|
1801 |
+
"metadata": {},
|
1802 |
+
"output_type": "execute_result"
|
1803 |
+
}
|
1804 |
+
],
|
1805 |
+
"source": [
|
1806 |
+
"model = DistilBertForMaskedLM.from_pretrained(\"distilbert-base-uncased\")\n",
|
1807 |
+
"config = DistilBertConfig.from_pretrained(\"distilbert-base-uncased\")\n",
|
1808 |
+
"mod = model.distilbert\n",
|
1809 |
+
"\n",
|
1810 |
+
"device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')\n",
|
1811 |
+
"model = ReuseQuestionDistilBERT(mod)\n",
|
1812 |
+
"model.to(device)"
|
1813 |
+
]
|
1814 |
+
},
|
1815 |
+
{
|
1816 |
+
"cell_type": "code",
|
1817 |
+
"execution_count": 71,
|
1818 |
+
"id": "d2c6bff5",
|
1819 |
+
"metadata": {},
|
1820 |
+
"outputs": [
|
1821 |
+
{
|
1822 |
+
"name": "stdout",
|
1823 |
+
"output_type": "stream",
|
1824 |
+
"text": [
|
1825 |
+
"+-------------------------------+------------+\n",
|
1826 |
+
"| Modules | Parameters |\n",
|
1827 |
+
"+-------------------------------+------------+\n",
|
1828 |
+
"| te.0.attention.q_lin.weight | 589824 |\n",
|
1829 |
+
"| te.0.attention.q_lin.bias | 768 |\n",
|
1830 |
+
"| te.0.attention.k_lin.weight | 589824 |\n",
|
1831 |
+
"| te.0.attention.k_lin.bias | 768 |\n",
|
1832 |
+
"| te.0.attention.v_lin.weight | 589824 |\n",
|
1833 |
+
"| te.0.attention.v_lin.bias | 768 |\n",
|
1834 |
+
"| te.0.attention.out_lin.weight | 589824 |\n",
|
1835 |
+
"| te.0.attention.out_lin.bias | 768 |\n",
|
1836 |
+
"| te.0.sa_layer_norm.weight | 768 |\n",
|
1837 |
+
"| te.0.sa_layer_norm.bias | 768 |\n",
|
1838 |
+
"| te.0.ffn.lin1.weight | 2359296 |\n",
|
1839 |
+
"| te.0.ffn.lin1.bias | 3072 |\n",
|
1840 |
+
"| te.0.ffn.lin2.weight | 2359296 |\n",
|
1841 |
+
"| te.0.ffn.lin2.bias | 768 |\n",
|
1842 |
+
"| te.0.output_layer_norm.weight | 768 |\n",
|
1843 |
+
"| te.0.output_layer_norm.bias | 768 |\n",
|
1844 |
+
"| te.1.attention.q_lin.weight | 589824 |\n",
|
1845 |
+
"| te.1.attention.q_lin.bias | 768 |\n",
|
1846 |
+
"| te.1.attention.k_lin.weight | 589824 |\n",
|
1847 |
+
"| te.1.attention.k_lin.bias | 768 |\n",
|
1848 |
+
"| te.1.attention.v_lin.weight | 589824 |\n",
|
1849 |
+
"| te.1.attention.v_lin.bias | 768 |\n",
|
1850 |
+
"| te.1.attention.out_lin.weight | 589824 |\n",
|
1851 |
+
"| te.1.attention.out_lin.bias | 768 |\n",
|
1852 |
+
"| te.1.sa_layer_norm.weight | 768 |\n",
|
1853 |
+
"| te.1.sa_layer_norm.bias | 768 |\n",
|
1854 |
+
"| te.1.ffn.lin1.weight | 2359296 |\n",
|
1855 |
+
"| te.1.ffn.lin1.bias | 3072 |\n",
|
1856 |
+
"| te.1.ffn.lin2.weight | 2359296 |\n",
|
1857 |
+
"| te.1.ffn.lin2.bias | 768 |\n",
|
1858 |
+
"| te.1.output_layer_norm.weight | 768 |\n",
|
1859 |
+
"| te.1.output_layer_norm.bias | 768 |\n",
|
1860 |
+
"| classifier.weight | 1536 |\n",
|
1861 |
+
"| classifier.bias | 2 |\n",
|
1862 |
+
"+-------------------------------+------------+\n",
|
1863 |
+
"Total Trainable Params: 14177282\n"
|
1864 |
+
]
|
1865 |
+
},
|
1866 |
+
{
|
1867 |
+
"data": {
|
1868 |
+
"text/plain": [
|
1869 |
+
"14177282"
|
1870 |
+
]
|
1871 |
+
},
|
1872 |
+
"execution_count": 71,
|
1873 |
+
"metadata": {},
|
1874 |
+
"output_type": "execute_result"
|
1875 |
+
}
|
1876 |
+
],
|
1877 |
+
"source": [
|
1878 |
+
"count_parameters(model)"
|
1879 |
+
]
|
1880 |
+
},
|
1881 |
+
{
|
1882 |
+
"cell_type": "markdown",
|
1883 |
+
"id": "c386c2eb",
|
1884 |
+
"metadata": {},
|
1885 |
+
"source": [
|
1886 |
+
"### Testing the Model"
|
1887 |
+
]
|
1888 |
+
},
|
1889 |
+
{
|
1890 |
+
"cell_type": "code",
|
1891 |
+
"execution_count": 72,
|
1892 |
+
"id": "818deed3",
|
1893 |
+
"metadata": {},
|
1894 |
+
"outputs": [],
|
1895 |
+
"source": [
|
1896 |
+
"# get smaller dataset\n",
|
1897 |
+
"batch_size = 8\n",
|
1898 |
+
"test_ds = Dataset(squad_paths = squad_paths[:2], natural_question_paths=None, hotpotqa_paths=None, tokenizer=tokenizer)\n",
|
1899 |
+
"test_ds_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size)\n",
|
1900 |
+
"optim=torch.optim.Adam(model.parameters())"
|
1901 |
+
]
|
1902 |
+
},
|
1903 |
+
{
|
1904 |
+
"cell_type": "code",
|
1905 |
+
"execution_count": 73,
|
1906 |
+
"id": "9da40760",
|
1907 |
+
"metadata": {},
|
1908 |
+
"outputs": [
|
1909 |
+
{
|
1910 |
+
"name": "stdout",
|
1911 |
+
"output_type": "stream",
|
1912 |
+
"text": [
|
1913 |
+
"Passed\n"
|
1914 |
+
]
|
1915 |
+
}
|
1916 |
+
],
|
1917 |
+
"source": [
|
1918 |
+
"test_model(model, optim, test_ds_loader, device)"
|
1919 |
+
]
|
1920 |
+
},
|
1921 |
+
{
|
1922 |
+
"cell_type": "markdown",
|
1923 |
+
"id": "c3f80248",
|
1924 |
+
"metadata": {},
|
1925 |
+
"source": [
|
1926 |
+
"### Model Training"
|
1927 |
+
]
|
1928 |
+
},
|
1929 |
+
{
|
1930 |
+
"cell_type": "code",
|
1931 |
+
"execution_count": 24,
|
1932 |
+
"id": "e1adabe6",
|
1933 |
+
"metadata": {},
|
1934 |
+
"outputs": [],
|
1935 |
+
"source": [
|
1936 |
+
"from torch.optim import AdamW, RMSprop\n",
|
1937 |
+
"\n",
|
1938 |
+
"model.train()\n",
|
1939 |
+
"optim = AdamW(model.parameters(), lr=1e-4)"
|
1940 |
+
]
|
1941 |
+
},
|
1942 |
+
{
|
1943 |
+
"cell_type": "code",
|
1944 |
+
"execution_count": 25,
|
1945 |
+
"id": "efe1cbd5",
|
1946 |
+
"metadata": {},
|
1947 |
+
"outputs": [
|
1948 |
+
{
|
1949 |
+
"data": {
|
1950 |
+
"application/vnd.jupyter.widget-view+json": {
|
1951 |
+
"model_id": "8785757b04214102830ded36c1392c8d",
|
1952 |
+
"version_major": 2,
|
1953 |
+
"version_minor": 0
|
1954 |
+
},
|
1955 |
+
"text/plain": [
|
1956 |
+
" 0%| | 0/35000 [00:00<?, ?it/s]"
|
1957 |
+
]
|
1958 |
+
},
|
1959 |
+
"metadata": {},
|
1960 |
+
"output_type": "display_data"
|
1961 |
+
},
|
1962 |
+
{
|
1963 |
+
"name": "stdout",
|
1964 |
+
"output_type": "stream",
|
1965 |
+
"text": [
|
1966 |
+
"Mean Training Error 2.6535016193100383\n"
|
1967 |
+
]
|
1968 |
+
},
|
1969 |
+
{
|
1970 |
+
"data": {
|
1971 |
+
"application/vnd.jupyter.widget-view+json": {
|
1972 |
+
"model_id": "836f5365498642fa9ae891a86dca5892",
|
1973 |
+
"version_major": 2,
|
1974 |
+
"version_minor": 0
|
1975 |
+
},
|
1976 |
+
"text/plain": [
|
1977 |
+
" 0%| | 0/2500 [00:00<?, ?it/s]"
|
1978 |
+
]
|
1979 |
+
},
|
1980 |
+
"metadata": {},
|
1981 |
+
"output_type": "display_data"
|
1982 |
+
},
|
1983 |
+
{
|
1984 |
+
"name": "stdout",
|
1985 |
+
"output_type": "stream",
|
1986 |
+
"text": [
|
1987 |
+
"Mean Test Error 2.384517493388057\n"
|
1988 |
+
]
|
1989 |
+
},
|
1990 |
+
{
|
1991 |
+
"data": {
|
1992 |
+
"application/vnd.jupyter.widget-view+json": {
|
1993 |
+
"model_id": "981e1cef83a1477e920d1cdbffdfcde1",
|
1994 |
+
"version_major": 2,
|
1995 |
+
"version_minor": 0
|
1996 |
+
},
|
1997 |
+
"text/plain": [
|
1998 |
+
" 0%| | 0/35000 [00:00<?, ?it/s]"
|
1999 |
+
]
|
2000 |
+
},
|
2001 |
+
"metadata": {},
|
2002 |
+
"output_type": "display_data"
|
2003 |
+
},
|
2004 |
+
{
|
2005 |
+
"name": "stdout",
|
2006 |
+
"output_type": "stream",
|
2007 |
+
"text": [
|
2008 |
+
"Mean Training Error 2.172889394424643\n"
|
2009 |
+
]
|
2010 |
+
},
|
2011 |
+
{
|
2012 |
+
"data": {
|
2013 |
+
"application/vnd.jupyter.widget-view+json": {
|
2014 |
+
"model_id": "20a785e7fefb43239f1120992d2c3416",
|
2015 |
+
"version_major": 2,
|
2016 |
+
"version_minor": 0
|
2017 |
+
},
|
2018 |
+
"text/plain": [
|
2019 |
+
" 0%| | 0/2500 [00:00<?, ?it/s]"
|
2020 |
+
]
|
2021 |
+
},
|
2022 |
+
"metadata": {},
|
2023 |
+
"output_type": "display_data"
|
2024 |
+
},
|
2025 |
+
{
|
2026 |
+
"name": "stdout",
|
2027 |
+
"output_type": "stream",
|
2028 |
+
"text": [
|
2029 |
+
"Mean Test Error 2.013008696398139\n"
|
2030 |
+
]
|
2031 |
+
},
|
2032 |
+
{
|
2033 |
+
"data": {
|
2034 |
+
"application/vnd.jupyter.widget-view+json": {
|
2035 |
+
"model_id": "47831e65b1ed4be78e8e7cb24068b0c3",
|
2036 |
+
"version_major": 2,
|
2037 |
+
"version_minor": 0
|
2038 |
+
},
|
2039 |
+
"text/plain": [
|
2040 |
+
" 0%| | 0/35000 [00:00<?, ?it/s]"
|
2041 |
+
]
|
2042 |
+
},
|
2043 |
+
"metadata": {},
|
2044 |
+
"output_type": "display_data"
|
2045 |
+
},
|
2046 |
+
{
|
2047 |
+
"name": "stdout",
|
2048 |
+
"output_type": "stream",
|
2049 |
+
"text": [
|
2050 |
+
"Mean Training Error 1.9743544759827\n"
|
2051 |
+
]
|
2052 |
+
},
|
2053 |
+
{
|
2054 |
+
"data": {
|
2055 |
+
"application/vnd.jupyter.widget-view+json": {
|
2056 |
+
"model_id": "15904a3f930249fb944ea87184676e14",
|
2057 |
+
"version_major": 2,
|
2058 |
+
"version_minor": 0
|
2059 |
+
},
|
2060 |
+
"text/plain": [
|
2061 |
+
" 0%| | 0/2500 [00:00<?, ?it/s]"
|
2062 |
+
]
|
2063 |
+
},
|
2064 |
+
"metadata": {},
|
2065 |
+
"output_type": "display_data"
|
2066 |
+
},
|
2067 |
+
{
|
2068 |
+
"name": "stdout",
|
2069 |
+
"output_type": "stream",
|
2070 |
+
"text": [
|
2071 |
+
"Mean Test Error 1.8922049684919418\n"
|
2072 |
+
]
|
2073 |
+
},
|
2074 |
+
{
|
2075 |
+
"data": {
|
2076 |
+
"application/vnd.jupyter.widget-view+json": {
|
2077 |
+
"model_id": "108bdbf644d94d78910195992b9e2652",
|
2078 |
+
"version_major": 2,
|
2079 |
+
"version_minor": 0
|
2080 |
+
},
|
2081 |
+
"text/plain": [
|
2082 |
+
" 0%| | 0/35000 [00:00<?, ?it/s]"
|
2083 |
+
]
|
2084 |
+
},
|
2085 |
+
"metadata": {},
|
2086 |
+
"output_type": "display_data"
|
2087 |
+
},
|
2088 |
+
{
|
2089 |
+
"name": "stdout",
|
2090 |
+
"output_type": "stream",
|
2091 |
+
"text": [
|
2092 |
+
"Mean Training Error 1.857202093189742\n"
|
2093 |
+
]
|
2094 |
+
},
|
2095 |
+
{
|
2096 |
+
"data": {
|
2097 |
+
"application/vnd.jupyter.widget-view+json": {
|
2098 |
+
"model_id": "d6a75a6ab40d4a2599b7511bfc60bf83",
|
2099 |
+
"version_major": 2,
|
2100 |
+
"version_minor": 0
|
2101 |
+
},
|
2102 |
+
"text/plain": [
|
2103 |
+
" 0%| | 0/2500 [00:00<?, ?it/s]"
|
2104 |
+
]
|
2105 |
+
},
|
2106 |
+
"metadata": {},
|
2107 |
+
"output_type": "display_data"
|
2108 |
+
},
|
2109 |
+
{
|
2110 |
+
"name": "stdout",
|
2111 |
+
"output_type": "stream",
|
2112 |
+
"text": [
|
2113 |
+
"Mean Test Error 1.793771461571753\n"
|
2114 |
+
]
|
2115 |
+
},
|
2116 |
+
{
|
2117 |
+
"data": {
|
2118 |
+
"application/vnd.jupyter.widget-view+json": {
|
2119 |
+
"model_id": "d3468a6ba72a4f42b0e7cc77ee0a0011",
|
2120 |
+
"version_major": 2,
|
2121 |
+
"version_minor": 0
|
2122 |
+
},
|
2123 |
+
"text/plain": [
|
2124 |
+
" 0%| | 0/35000 [00:00<?, ?it/s]"
|
2125 |
+
]
|
2126 |
+
},
|
2127 |
+
"metadata": {},
|
2128 |
+
"output_type": "display_data"
|
2129 |
+
},
|
2130 |
+
{
|
2131 |
+
"name": "stdout",
|
2132 |
+
"output_type": "stream",
|
2133 |
+
"text": [
|
2134 |
+
"Mean Training Error 1.7750537034896867\n"
|
2135 |
+
]
|
2136 |
+
},
|
2137 |
+
{
|
2138 |
+
"data": {
|
2139 |
+
"application/vnd.jupyter.widget-view+json": {
|
2140 |
+
"model_id": "8aca0aa529d2452e8bd29fe7ada934f2",
|
2141 |
+
"version_major": 2,
|
2142 |
+
"version_minor": 0
|
2143 |
+
},
|
2144 |
+
"text/plain": [
|
2145 |
+
" 0%| | 0/2500 [00:00<?, ?it/s]"
|
2146 |
+
]
|
2147 |
+
},
|
2148 |
+
"metadata": {},
|
2149 |
+
"output_type": "display_data"
|
2150 |
+
},
|
2151 |
+
{
|
2152 |
+
"name": "stdout",
|
2153 |
+
"output_type": "stream",
|
2154 |
+
"text": [
|
2155 |
+
"Mean Test Error 1.7466133671954274\n"
|
2156 |
+
]
|
2157 |
+
},
|
2158 |
+
{
|
2159 |
+
"data": {
|
2160 |
+
"application/vnd.jupyter.widget-view+json": {
|
2161 |
+
"model_id": "e09abdfa63c841ce97f445ba9b3eeaa8",
|
2162 |
+
"version_major": 2,
|
2163 |
+
"version_minor": 0
|
2164 |
+
},
|
2165 |
+
"text/plain": [
|
2166 |
+
" 0%| | 0/35000 [00:00<?, ?it/s]"
|
2167 |
+
]
|
2168 |
+
},
|
2169 |
+
"metadata": {},
|
2170 |
+
"output_type": "display_data"
|
2171 |
+
},
|
2172 |
+
{
|
2173 |
+
"name": "stdout",
|
2174 |
+
"output_type": "stream",
|
2175 |
+
"text": [
|
2176 |
+
"Mean Training Error 1.7097622096568346\n"
|
2177 |
+
]
|
2178 |
+
},
|
2179 |
+
{
|
2180 |
+
"data": {
|
2181 |
+
"application/vnd.jupyter.widget-view+json": {
|
2182 |
+
"model_id": "0f49dd32d33e4f398be0942a59d735ce",
|
2183 |
+
"version_major": 2,
|
2184 |
+
"version_minor": 0
|
2185 |
+
},
|
2186 |
+
"text/plain": [
|
2187 |
+
" 0%| | 0/2500 [00:00<?, ?it/s]"
|
2188 |
+
]
|
2189 |
+
},
|
2190 |
+
"metadata": {},
|
2191 |
+
"output_type": "display_data"
|
2192 |
+
},
|
2193 |
+
{
|
2194 |
+
"name": "stdout",
|
2195 |
+
"output_type": "stream",
|
2196 |
+
"text": [
|
2197 |
+
"Mean Test Error 1.7642206047609448\n"
|
2198 |
+
]
|
2199 |
+
},
|
2200 |
+
{
|
2201 |
+
"data": {
|
2202 |
+
"application/vnd.jupyter.widget-view+json": {
|
2203 |
+
"model_id": "a493dd70ffb64cd19830e5dc98607979",
|
2204 |
+
"version_major": 2,
|
2205 |
+
"version_minor": 0
|
2206 |
+
},
|
2207 |
+
"text/plain": [
|
2208 |
+
" 0%| | 0/35000 [00:00<?, ?it/s]"
|
2209 |
+
]
|
2210 |
+
},
|
2211 |
+
"metadata": {},
|
2212 |
+
"output_type": "display_data"
|
2213 |
+
},
|
2214 |
+
{
|
2215 |
+
"name": "stderr",
|
2216 |
+
"output_type": "stream",
|
2217 |
+
"text": [
|
2218 |
+
"\n",
|
2219 |
+
"KeyboardInterrupt\n",
|
2220 |
+
"\n"
|
2221 |
+
]
|
2222 |
+
}
|
2223 |
+
],
|
2224 |
+
"source": [
|
2225 |
+
"epochs = 16\n",
|
2226 |
+
"\n",
|
2227 |
+
"for epoch in range(epochs):\n",
|
2228 |
+
" loop = tqdm(loader, leave=True)\n",
|
2229 |
+
" model.train()\n",
|
2230 |
+
" mean_training_error = []\n",
|
2231 |
+
" for batch in loop:\n",
|
2232 |
+
" optim.zero_grad()\n",
|
2233 |
+
" \n",
|
2234 |
+
" input_ids = batch['input_ids'].to(device)\n",
|
2235 |
+
" attention_mask = batch['attention_mask'].to(device)\n",
|
2236 |
+
" start = batch['start_positions'].to(device)\n",
|
2237 |
+
" end = batch['end_positions'].to(device)\n",
|
2238 |
+
" \n",
|
2239 |
+
" outputs = model(input_ids, attention_mask=attention_mask, start_positions=start, end_positions=end)\n",
|
2240 |
+
" # print(torch.argmax(outputs['start_logits'],axis=1), torch.argmax(outputs['end_logits'], axis=1), start, end)\n",
|
2241 |
+
" loss = outputs['loss']\n",
|
2242 |
+
" loss.backward()\n",
|
2243 |
+
" # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)\n",
|
2244 |
+
" optim.step()\n",
|
2245 |
+
" mean_training_error.append(loss.item())\n",
|
2246 |
+
" loop.set_description(f'Epoch {epoch}')\n",
|
2247 |
+
" loop.set_postfix(loss=loss.item())\n",
|
2248 |
+
" print(\"Mean Training Error\", np.mean(mean_training_error))\n",
|
2249 |
+
" \n",
|
2250 |
+
" loop = tqdm(test_loader, leave=True)\n",
|
2251 |
+
" model.eval()\n",
|
2252 |
+
" mean_test_error = []\n",
|
2253 |
+
" for batch in loop:\n",
|
2254 |
+
" \n",
|
2255 |
+
" input_ids = batch['input_ids'].to(device)\n",
|
2256 |
+
" attention_mask = batch['attention_mask'].to(device)\n",
|
2257 |
+
" start = batch['start_positions'].to(device)\n",
|
2258 |
+
" end = batch['end_positions'].to(device)\n",
|
2259 |
+
" \n",
|
2260 |
+
" outputs = model(input_ids, attention_mask=attention_mask, start_positions=start, end_positions=end)\n",
|
2261 |
+
" # print(torch.argmax(outputs['start_logits'],axis=1), torch.argmax(outputs['end_logits'], axis=1), start, end)\n",
|
2262 |
+
" loss = outputs['loss']\n",
|
2263 |
+
" \n",
|
2264 |
+
" mean_test_error.append(loss.item())\n",
|
2265 |
+
" loop.set_description(f'Epoch {epoch} Testset')\n",
|
2266 |
+
" loop.set_postfix(loss=loss.item())\n",
|
2267 |
+
" print(\"Mean Test Error\", np.mean(mean_test_error))\n",
|
2268 |
+
" torch.save(model.state_dict(), \"distilbert_reuse_{}\".format(epoch))"
|
2269 |
+
]
|
2270 |
+
},
|
2271 |
+
{
|
2272 |
+
"cell_type": "code",
|
2273 |
+
"execution_count": 48,
|
2274 |
+
"id": "fdf37d18",
|
2275 |
+
"metadata": {},
|
2276 |
+
"outputs": [],
|
2277 |
+
"source": [
|
2278 |
+
"torch.save(model.state_dict(), \"distilbert_reuse.model\")"
|
2279 |
+
]
|
2280 |
+
},
|
2281 |
+
{
|
2282 |
+
"cell_type": "code",
|
2283 |
+
"execution_count": 49,
|
2284 |
+
"id": "d1cfded4",
|
2285 |
+
"metadata": {},
|
2286 |
+
"outputs": [],
|
2287 |
+
"source": [
|
2288 |
+
"m = ReuseQuestionDistilBERT(mod)\n",
|
2289 |
+
"m.load_state_dict(torch.load(\"distilbert_reuse.model\"))\n",
|
2290 |
+
"model = m"
|
2291 |
+
]
|
2292 |
+
},
|
2293 |
+
{
|
2294 |
+
"cell_type": "code",
|
2295 |
+
"execution_count": 47,
|
2296 |
+
"id": "233bdc18",
|
2297 |
+
"metadata": {},
|
2298 |
+
"outputs": [
|
2299 |
+
{
|
2300 |
+
"name": "stderr",
|
2301 |
+
"output_type": "stream",
|
2302 |
+
"text": [
|
2303 |
+
"100%|鈻堚枅鈻堚枅鈻堚枅鈻堚枅鈻堚枅| 2500/2500 [02:51<00:00, 14.59it/s]"
|
2304 |
+
]
|
2305 |
+
},
|
2306 |
+
{
|
2307 |
+
"name": "stdout",
|
2308 |
+
"output_type": "stream",
|
2309 |
+
"text": [
|
2310 |
+
"Mean EM: 0.5178\n",
|
2311 |
+
"Mean F-1: 0.6671140689626448\n"
|
2312 |
+
]
|
2313 |
+
},
|
2314 |
+
{
|
2315 |
+
"name": "stderr",
|
2316 |
+
"output_type": "stream",
|
2317 |
+
"text": [
|
2318 |
+
"\n"
|
2319 |
+
]
|
2320 |
+
}
|
2321 |
+
],
|
2322 |
+
"source": [
|
2323 |
+
"eval_test_set(model, tokenizer, test_loader, device)"
|
2324 |
+
]
|
2325 |
+
},
|
2326 |
+
{
|
2327 |
+
"cell_type": "code",
|
2328 |
+
"execution_count": null,
|
2329 |
+
"id": "0fb1ce9e",
|
2330 |
+
"metadata": {},
|
2331 |
+
"outputs": [],
|
2332 |
+
"source": []
|
2333 |
+
}
|
2334 |
+
],
|
2335 |
+
"metadata": {
|
2336 |
+
"kernelspec": {
|
2337 |
+
"display_name": "Python 3.10.8 ('venv': venv)",
|
2338 |
+
"language": "python",
|
2339 |
+
"name": "python3"
|
2340 |
+
},
|
2341 |
+
"language_info": {
|
2342 |
+
"codemirror_mode": {
|
2343 |
+
"name": "ipython",
|
2344 |
+
"version": 3
|
2345 |
+
},
|
2346 |
+
"file_extension": ".py",
|
2347 |
+
"mimetype": "text/x-python",
|
2348 |
+
"name": "python",
|
2349 |
+
"nbconvert_exporter": "python",
|
2350 |
+
"pygments_lexer": "ipython3",
|
2351 |
+
"version": "3.10.8"
|
2352 |
+
},
|
2353 |
+
"toc": {
|
2354 |
+
"base_numbering": 1,
|
2355 |
+
"nav_menu": {},
|
2356 |
+
"number_sections": true,
|
2357 |
+
"sideBar": true,
|
2358 |
+
"skip_h1_title": false,
|
2359 |
+
"title_cell": "Table of Contents",
|
2360 |
+
"title_sidebar": "Contents",
|
2361 |
+
"toc_cell": false,
|
2362 |
+
"toc_position": {},
|
2363 |
+
"toc_section_display": true,
|
2364 |
+
"toc_window_display": false
|
2365 |
+
},
|
2366 |
+
"varInspector": {
|
2367 |
+
"cols": {
|
2368 |
+
"lenName": 16,
|
2369 |
+
"lenType": 16,
|
2370 |
+
"lenVar": 40
|
2371 |
+
},
|
2372 |
+
"kernels_config": {
|
2373 |
+
"python": {
|
2374 |
+
"delete_cmd_postfix": "",
|
2375 |
+
"delete_cmd_prefix": "del ",
|
2376 |
+
"library": "var_list.py",
|
2377 |
+
"varRefreshCmd": "print(var_dic_list())"
|
2378 |
+
},
|
2379 |
+
"r": {
|
2380 |
+
"delete_cmd_postfix": ") ",
|
2381 |
+
"delete_cmd_prefix": "rm(",
|
2382 |
+
"library": "var_list.r",
|
2383 |
+
"varRefreshCmd": "cat(var_dic_list()) "
|
2384 |
+
}
|
2385 |
+
},
|
2386 |
+
"types_to_exclude": [
|
2387 |
+
"module",
|
2388 |
+
"function",
|
2389 |
+
"builtin_function_or_method",
|
2390 |
+
"instance",
|
2391 |
+
"_Feature"
|
2392 |
+
],
|
2393 |
+
"window_display": false
|
2394 |
+
},
|
2395 |
+
"vscode": {
|
2396 |
+
"interpreter": {
|
2397 |
+
"hash": "85bf9c14e9ba73b783ed1274d522bec79eb0b2b739090180d8ce17bb11aff4aa"
|
2398 |
+
}
|
2399 |
+
}
|
2400 |
+
},
|
2401 |
+
"nbformat": 4,
|
2402 |
+
"nbformat_minor": 5
|
2403 |
+
}
|
requirements.txt
ADDED
@@ -0,0 +1,168 @@
1 |
+
absl-py==1.3.0
|
2 |
+
aiohttp==3.8.3
|
3 |
+
aiosignal==1.2.0
|
4 |
+
altair==4.2.0
|
5 |
+
apache-beam>=2.41.0
|
6 |
+
argon2-cffi==21.3.0
|
7 |
+
argon2-cffi-bindings==21.2.0
|
8 |
+
asttokens==2.0.8
|
9 |
+
async-timeout==4.0.2
|
10 |
+
attrs==22.1.0
|
11 |
+
autopep8==1.7.0
|
12 |
+
backcall==0.2.0
|
13 |
+
beautifulsoup4==4.11.1
|
14 |
+
bleach==5.0.1
|
15 |
+
blinker==1.5
|
16 |
+
cachetools==5.2.0
|
17 |
+
certifi==2022.9.24
|
18 |
+
cffi==1.15.1
|
19 |
+
charset-normalizer==2.1.1
|
20 |
+
click==8.1.3
|
21 |
+
cloudpickle==2.2.0
|
22 |
+
commonmark==0.9.1
|
23 |
+
contourpy==1.0.5
|
24 |
+
crcmod==1.7
|
25 |
+
cycler==0.11.0
|
26 |
+
datasets==2.5.2
|
27 |
+
debugpy==1.6.3
|
28 |
+
decorator==5.1.1
|
29 |
+
defusedxml==0.7.1
|
30 |
+
dill==0.3.1.1
|
31 |
+
docopt==0.6.2
|
32 |
+
entrypoints==0.4
|
33 |
+
executing==1.1.0
|
34 |
+
fastavro==1.6.1
|
35 |
+
fastjsonschema==2.16.2
|
36 |
+
filelock==3.8.0
|
37 |
+
fonttools==4.37.4
|
38 |
+
frozenlist==1.3.1
|
39 |
+
fsspec==2022.8.2
|
40 |
+
gitdb==4.0.10
|
41 |
+
GitPython==3.1.29
|
42 |
+
google-auth==2.13.0
|
43 |
+
google-auth-oauthlib==0.4.6
|
44 |
+
grpcio==1.49.1
|
45 |
+
hdfs==2.7.0
|
46 |
+
httplib2==0.20.4
|
47 |
+
huggingface-hub==0.10.0
|
48 |
+
idna==3.4
|
49 |
+
importlib-metadata==5.1.0
|
50 |
+
ipykernel==6.16.0
|
51 |
+
ipython==8.5.0
|
52 |
+
ipython-genutils==0.2.0
|
53 |
+
ipywidgets==8.0.2
|
54 |
+
jedi==0.18.1
|
55 |
+
Jinja2==3.1.2
|
56 |
+
joblib==1.2.0
|
57 |
+
jsonschema==4.16.0
|
58 |
+
jupyter==1.0.0
|
59 |
+
jupyter-console==6.4.4
|
60 |
+
jupyter-contrib-core==0.4.0
|
61 |
+
jupyter-contrib-nbextensions==0.5.1
|
62 |
+
jupyter-highlight-selected-word==0.2.0
|
63 |
+
jupyter-latex-envs==1.4.6
|
64 |
+
jupyter-nbextensions-configurator==0.5.0
|
65 |
+
jupyter_client==7.3.5
|
66 |
+
jupyter_core==4.11.2
|
67 |
+
jupyterlab-pygments==0.2.2
|
68 |
+
jupyterlab-widgets==3.0.3
|
69 |
+
kiwisolver==1.4.4
|
70 |
+
lesscpy==0.15.1
|
71 |
+
lxml==4.9.1
|
72 |
+
Markdown==3.4.1
|
73 |
+
MarkupSafe==2.1.1
|
74 |
+
matplotlib==3.6.1
|
75 |
+
matplotlib-inline==0.1.6
|
76 |
+
mistune==2.0.4
|
77 |
+
multidict==6.0.2
|
78 |
+
multiprocess==0.70.9
|
79 |
+
mwparserfromhell==0.6.4
|
80 |
+
nbclient==0.7.0
|
81 |
+
nbconvert==7.2.1
|
82 |
+
nbformat==5.6.1
|
83 |
+
nest-asyncio==1.5.6
|
84 |
+
notebook==6.4.12
|
85 |
+
numpy==1.22.4
|
86 |
+
oauthlib==3.2.2
|
87 |
+
orjson==3.9.7
|
88 |
+
packaging==21.3
|
89 |
+
pandas==1.5.0
|
90 |
+
pandocfilters==1.5.0
|
91 |
+
parso==0.8.3
|
92 |
+
pexpect==4.8.0
|
93 |
+
pickleshare==0.7.5
|
94 |
+
Pillow==9.2.0
|
95 |
+
ply==3.11
|
96 |
+
prettytable==3.4.1
|
97 |
+
prometheus-client==0.14.1
|
98 |
+
prompt-toolkit==3.0.31
|
99 |
+
proto-plus==1.22.1
|
100 |
+
protobuf==3.19.6
|
101 |
+
psutil==5.9.2
|
102 |
+
ptyprocess==0.7.0
|
103 |
+
pure-eval==0.2.2
|
104 |
+
pyarrow==7.0.0
|
105 |
+
pyasn1==0.4.8
|
106 |
+
pyasn1-modules==0.2.8
|
107 |
+
pycodestyle==2.9.1
|
108 |
+
pycparser==2.21
|
109 |
+
pydeck==0.8.0
|
110 |
+
pydot==1.4.2
|
111 |
+
Pygments==2.13.0
|
112 |
+
pymongo==3.12.3
|
113 |
+
Pympler==1.0.1
|
114 |
+
pyparsing==3.0.9
|
115 |
+
pyrsistent==0.18.1
|
116 |
+
python-dateutil==2.8.2
|
117 |
+
pytz==2022.4
|
118 |
+
pytz-deprecation-shim==0.1.0.post0
|
119 |
+
PyYAML==6.0
|
120 |
+
pyzmq==24.0.1
|
121 |
+
qtconsole==5.3.2
|
122 |
+
QtPy==2.2.1
|
123 |
+
regex==2022.9.13
|
124 |
+
requests==2.28.1
|
125 |
+
requests-oauthlib==1.3.1
|
126 |
+
responses==0.18.0
|
127 |
+
rich==12.6.0
|
128 |
+
rsa==4.9
|
129 |
+
scikit-learn==1.1.2
|
130 |
+
scipy==1.9.1
|
131 |
+
semver==2.13.0
|
132 |
+
Send2Trash==1.8.0
|
133 |
+
six==1.16.0
|
134 |
+
smmap==5.0.0
|
135 |
+
soupsieve==2.3.2.post1
|
136 |
+
stack-data==0.5.1
|
137 |
+
streamlit==1.15.2
|
138 |
+
tensorboard==2.10.1
|
139 |
+
tensorboard-data-server==0.6.1
|
140 |
+
tensorboard-plugin-wit==1.8.1
|
141 |
+
terminado==0.16.0
|
142 |
+
threadpoolctl==3.1.0
|
143 |
+
tinycss2==1.1.1
|
144 |
+
tokenizers==0.12.1
|
145 |
+
toml==0.10.2
|
146 |
+
toolz==0.12.0
|
147 |
+
torch==1.12.1
|
148 |
+
torchaudio==0.12.1
|
149 |
+
torchsummary==1.5.1
|
150 |
+
torchtest==0.5
|
151 |
+
torchvision==0.13.1
|
152 |
+
tornado==6.2
|
153 |
+
tqdm==4.64.1
|
154 |
+
traitlets==5.4.0
|
155 |
+
transformers==4.22.2
|
156 |
+
typing_extensions==4.4.0
|
157 |
+
tzdata==2022.7
|
158 |
+
tzlocal==4.2
|
159 |
+
urllib3==1.26.12
|
160 |
+
validators==0.20.0
|
161 |
+
watchdog==2.2.0
|
162 |
+
wcwidth==0.2.5
|
163 |
+
webencodings==0.5.1
|
164 |
+
Werkzeug==2.2.2
|
165 |
+
widgetsnbextension==4.0.3
|
166 |
+
xxhash==3.0.0
|
167 |
+
yarl==1.8.1
|
168 |
+
zipp==3.11.0
|
util.py
ADDED
@@ -0,0 +1,134 @@
1 |
+
import re
|
2 |
+
import numpy as np
|
3 |
+
from prettytable import PrettyTable
|
4 |
+
from tqdm import tqdm
|
5 |
+
import torch
|
6 |
+
|
7 |
+
|
8 |
+
def normalize_text(s):
|
9 |
+
"""
|
10 |
+
Removes articles and punctuation and standardizes whitespace; these are typical text-processing steps.
|
11 |
+
Copied from: https://qa.fastforwardlabs.com/no%20answer/null%20threshold/bert/distilbert/exact%20match/f1/robust%20predictions/2020/06/09/Evaluating_BERT_on_SQuAD.html#Metrics-for-QA
|
12 |
+
:param s: string to clean
|
13 |
+
:return: cleaned string
|
14 |
+
"""
|
15 |
+
import string, re
|
16 |
+
|
17 |
+
def remove_articles(text):
|
18 |
+
regex = re.compile(r"\b(a|an|the)\b", re.UNICODE)
|
19 |
+
return re.sub(regex, " ", text)
|
20 |
+
|
21 |
+
def white_space_fix(text):
|
22 |
+
return " ".join(text.split())
|
23 |
+
|
24 |
+
def remove_punc(text):
|
25 |
+
exclude = set(string.punctuation)
|
26 |
+
return "".join(ch for ch in text if ch not in exclude)
|
27 |
+
|
28 |
+
def lower(text):
|
29 |
+
return text.lower()
|
30 |
+
|
31 |
+
return white_space_fix(remove_articles(remove_punc(lower(s))))
|
32 |
+
|
33 |
+
|
34 |
+
def compute_exact_match(prediction, truth):
|
35 |
+
"""
|
36 |
+
Returns 1 if the prediction is an exact match with the ground truth, else 0
|
37 |
+
Retrieved from: https://qa.fastforwardlabs.com/no%20answer/null%20threshold/bert/distilbert/exact%20match/f1/robust%20predictions/2020/06/09/Evaluating_BERT_on_SQuAD.html#Metrics-for-QA
|
38 |
+
:param prediction: predicted answer
|
39 |
+
:param truth: ground truth
|
40 |
+
:return: 1 if exact match, else 0
|
41 |
+
"""
|
42 |
+
return int(normalize_text(prediction) == normalize_text(truth))
|
43 |
+
|
44 |
+
|
45 |
+
def compute_f1(prediction, truth):
|
46 |
+
"""
|
47 |
+
Computes the F-1 score of a prediction, based on the tokens
|
48 |
+
Retrieved from: https://qa.fastforwardlabs.com/no%20answer/null%20threshold/bert/distilbert/exact%20match/f1/robust%20predictions/2020/06/09/Evaluating_BERT_on_SQuAD.html#Metrics-for-QA
|
49 |
+
:param prediction: predicted answer
|
50 |
+
:param truth: ground truth
|
51 |
+
:return: the f-1 score of the prediction
|
52 |
+
"""
|
53 |
+
pred_tokens = normalize_text(prediction).split()
|
54 |
+
truth_tokens = normalize_text(truth).split()
|
55 |
+
|
56 |
+
# if either the prediction or the truth is no-answer then f1 = 1 if they agree, 0 otherwise
|
57 |
+
if len(pred_tokens) == 0 or len(truth_tokens) == 0:
|
58 |
+
return int(pred_tokens == truth_tokens)
|
59 |
+
|
60 |
+
# get tokens that are in the prediction and gt
|
61 |
+
common_tokens = set(pred_tokens) & set(truth_tokens)
|
62 |
+
|
63 |
+
# if there are no common tokens then f1 = 0
|
64 |
+
if len(common_tokens) == 0:
|
65 |
+
return 0
|
66 |
+
|
67 |
+
# calculate precision and recall
|
68 |
+
prec = len(common_tokens) / len(pred_tokens)
|
69 |
+
rec = len(common_tokens) / len(truth_tokens)
|
70 |
+
|
71 |
+
return 2 * (prec * rec) / (prec + rec)
|
72 |
+
|
73 |
+
def eval_test_set(model, tokenizer, test_loader, device):
|
74 |
+
"""
|
75 |
+
Calculates the mean EM and mean F-1 score on the test set
|
76 |
+
:param model: pytorch model
|
77 |
+
:param tokenizer: tokenizer used to encode the samples
|
78 |
+
:param test_loader: dataloader object with test data
|
79 |
+
:param device: device the model is on
|
80 |
+
"""
|
81 |
+
mean_em = []
|
82 |
+
mean_f1 = []
|
83 |
+
model.to(device)
|
84 |
+
model.eval()
|
85 |
+
for batch in tqdm(test_loader):
|
86 |
+
# get test data and transfer to device
|
87 |
+
input_ids = batch['input_ids'].to(device)
|
88 |
+
attention_mask = batch['attention_mask'].to(device)
|
89 |
+
start = batch['start_positions'].to(device)
|
90 |
+
end = batch['end_positions'].to(device)
|
91 |
+
|
92 |
+
# predict
|
93 |
+
outputs = model(input_ids, attention_mask=attention_mask, start_positions=start, end_positions=end)
|
94 |
+
|
95 |
+
# iterate over samples, calculate EM and F-1 for all
|
96 |
+
for input_i, s, e, trues, truee in zip(input_ids, outputs['start_logits'], outputs['end_logits'], start, end):
|
97 |
+
# get predicted start and end token indices (argmax over the logits)
|
98 |
+
start_logits = torch.argmax(s)
|
99 |
+
end_logits = torch.argmax(e)
|
100 |
+
|
101 |
+
# get predicted answer as string
|
102 |
+
ans_tokens = input_i[start_logits: end_logits + 1]
|
103 |
+
answer_tokens = tokenizer.convert_ids_to_tokens(ans_tokens, skip_special_tokens=True)
|
104 |
+
predicted = tokenizer.convert_tokens_to_string(answer_tokens)
|
105 |
+
|
106 |
+
# get ground truth as string
|
107 |
+
ans_tokens = input_i[trues: truee + 1]
|
108 |
+
answer_tokens = tokenizer.convert_ids_to_tokens(ans_tokens, skip_special_tokens=True)
|
109 |
+
true = tokenizer.convert_tokens_to_string(answer_tokens)
|
110 |
+
|
111 |
+
# compute score
|
112 |
+
em_score = compute_exact_match(predicted, true)
|
113 |
+
f1_score = compute_f1(predicted, true)
|
114 |
+
mean_em.append(em_score)
|
115 |
+
mean_f1.append(f1_score)
|
116 |
+
print("Mean EM: ", np.mean(mean_em))
|
117 |
+
print("Mean F-1: ", np.mean(mean_f1))
|
118 |
+
|
119 |
+
def count_parameters(model):
|
120 |
+
"""
|
121 |
+
This function prints statistics about the trainable parameters
|
122 |
+
:param model: pytorch model
|
123 |
+
:return: total number of trainable parameters
|
124 |
+
"""
|
125 |
+
table = PrettyTable(["Modules", "Parameters"])
|
126 |
+
total_params = 0
|
127 |
+
for name, parameter in model.named_parameters():
|
128 |
+
if not parameter.requires_grad: continue
|
129 |
+
params = parameter.numel()
|
130 |
+
table.add_row([name, params])
|
131 |
+
total_params += params
|
132 |
+
print(table)
|
133 |
+
print(f"Total Trainable Params: {total_params}")
|
134 |
+
return total_params
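
A minimal, hypothetical usage sketch of the metric helpers defined in util.py above (toy strings chosen for illustration, not taken from any dataset):

    from util import normalize_text, compute_exact_match, compute_f1

    pred, truth = "The Eiffel Tower", "eiffel tower"
    normalize_text(pred)                     # "eiffel tower" (lowercased, article and punctuation removed)
    compute_exact_match(pred, truth)         # 1, because the normalized strings match
    compute_f1("eiffel tower paris", truth)  # common tokens {"eiffel", "tower"}: precision 2/3, recall 1 -> F-1 = 0.8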
|