Robert committed
Commit · b7158e7
Parent(s): 9889a50

- Remove useless paragraphs that only contain formulas
- Added some code to run the script over all questions to calculate overall performance

Files changed:
- main.py (+36 -2)
- src/retrievers/faiss_retriever.py (+3 -0)
- src/utils/preprocessing.py (+19 -0)
main.py CHANGED

@@ -1,6 +1,7 @@
 import os
 import random
 from typing import cast
+import time
 
 import torch
 import transformers
@@ -32,8 +33,8 @@ if __name__ == '__main__':
         "GroNLP/ik-nlp-22_slp", "paragraphs"))
 
     # Initialize retriever
-
-    retriever = ESRetriever(dataset_paragraphs)
+    retriever = FaissRetriever(dataset_paragraphs)
+    #retriever = ESRetriever(dataset_paragraphs)
 
     # Retrieve example
     # random.seed(111)
@@ -84,3 +85,36 @@ if __name__ == '__main__':
           f"Predicted answer: {answers[highest_index].text}\n"
           f"Exact match: {exact_match:.02f}\n"
           f"F1-score: {f1_score:.02f}")
+
+    # Calculate overall performance
+    # total_f1 = 0
+    # total_exact = 0
+    # total_len = len(questions_test["question"])
+    # start_time = time.time()
+    # for i, question in enumerate(questions_test["question"]):
+    #     print(question)
+    #     answer = questions_test["answer"][i]
+    #     print(answer)
+    #
+    #     scores, result = retriever.retrieve(question)
+    #     reader_input = result_to_reader_input(result)
+    #     answers = reader.read(question, reader_input)
+    #
+    #     document_scores = sm(torch.Tensor(
+    #         [pred.relevance_score for pred in answers]))
+    #     span_scores = sm(torch.Tensor(
+    #         [pred.span_score for pred in answers]))
+    #
+    #     highest, highest_index = 0, 0
+    #     for j, value in enumerate(span_scores):
+    #         if value + document_scores[j] > highest:
+    #             highest = value + document_scores[j]
+    #             highest_index = j
+    #     print(answers[highest_index])
+    #     exact_match, f1_score = evaluate(answer, answers[highest_index].text)
+    #     total_f1 += f1_score
+    #     total_exact += exact_match
+    # print(f"Total time:", round(time.time() - start_time, 2), "seconds.")
+    # print(total_f1)
+    # print(total_exact)
+    # print(total_f1/total_len)
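
Note on the block added at the end of main.py: it is committed entirely commented out. The sketch below factors the same loop into a standalone helper so it can be toggled without uncommenting thirty lines. It is not part of the commit and assumes the repository's interfaces as used above: retriever.retrieve(question) returning (scores, result), reader.read(question, reader_input) returning predictions with .relevance_score, .span_score and .text, plus the project's result_to_reader_input and evaluate helpers; sm is assumed to be torch.nn.Softmax(dim=0) as in main.py.

import time

import torch


def evaluate_all(retriever, reader, questions, answers,
                 result_to_reader_input, evaluate):
    """Run retrieval + reading over every test question and report averages.

    Sketch only: mirrors the commented-out loop in main.py under the
    interface assumptions described above.
    """
    sm = torch.nn.Softmax(dim=0)
    total_f1, total_exact = 0.0, 0.0
    start_time = time.time()

    for question, answer in zip(questions, answers):
        _, result = retriever.retrieve(question)
        reader_input = result_to_reader_input(result)
        predictions = reader.read(question, reader_input)

        # Softmax document relevance and span scores, then pick the span with
        # the highest combined score (same selection rule as the loop above,
        # since softmax outputs are non-negative).
        document_scores = sm(torch.tensor(
            [pred.relevance_score for pred in predictions]))
        span_scores = sm(torch.tensor(
            [pred.span_score for pred in predictions]))
        best = int(torch.argmax(document_scores + span_scores))

        exact_match, f1_score = evaluate(answer, predictions[best].text)
        total_f1 += f1_score
        total_exact += exact_match

    n = len(questions)
    print(f"Total time: {time.time() - start_time:.2f} seconds.")
    print(f"Average exact match: {total_exact / n:.02f}")
    print(f"Average F1-score: {total_f1 / n:.02f}")
    return total_exact / n, total_f1 / n

Called as evaluate_all(retriever, reader, questions_test["question"], questions_test["answer"], result_to_reader_input, evaluate) at the end of main.py, it would print the same averages the commented block computes.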
src/retrievers/faiss_retriever.py CHANGED

@@ -12,6 +12,7 @@ from transformers import (
 
 from src.retrievers.base_retriever import Retriever
 from src.utils.log import get_logger
+from src.utils.preprocessing import remove_formulas
 
 # Hacky fix for FAISS error on macOS
 # See https://stackoverflow.com/a/63374568/4545692
@@ -55,6 +56,8 @@ class FaissRetriever(Retriever):
                  force_new_embedding: bool = False):
 
         ds = self.dataset["train"]
+        ds = ds.map(remove_formulas)
+
 
         if not force_new_embedding and os.path.exists(self.embedding_path):
            ds.load_faiss_index(
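
The only functional change here is ds = ds.map(remove_formulas), which runs the new preprocessing function over every row of the "train" split before the FAISS embeddings are built or loaded. A minimal standalone illustration of what that map does, using a toy in-memory dataset instead of GroNLP/ik-nlp-22_slp and assuming the snippet is run from the repository root so src.utils.preprocessing is importable:

from datasets import Dataset

from src.utils.preprocessing import remove_formulas

# Toy stand-in for the paragraphs dataset: one prose row, and one formula-like
# row whose "words" are mostly single symbols (average word length <= 3.5).
ds = Dataset.from_dict({
    "text": [
        "Language models assign probabilities to sequences of words.",
        "P ( w | h ) = C ( h w ) / C ( h )",
    ]
})

ds = ds.map(remove_formulas)
print(ds["text"])
# The prose row is kept unchanged; the formula-like row is replaced by ''.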
src/utils/preprocessing.py CHANGED

@@ -33,3 +33,22 @@ def result_to_reader_input(result: Dict[str, List[str]]) \
         reader_result['texts'].append(result['text'][n])
 
     return reader_result
+
+
+def remove_formulas(ds):
+    """Replaces text in the 'text' column of the ds which has an average
+    word length of <= 3.5 with blanks. This essentially means that most
+    of the formulas are removed.
+    To-do:
+        - more preprocessing
+        - a summarization model perhaps
+    Args:
+        ds: HuggingFace dataset that contains the information for the retriever
+    Returns:
+        ds: preprocessed HuggingFace dataset
+    """
+    words = ds['text'].split()
+    average = sum(len(word) for word in words) / len(words)
+    if average <= 3.5:
+        ds['text'] = ''
+    return ds
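
One caveat with the committed heuristic: remove_formulas divides by len(words), so a row whose 'text' is empty or whitespace-only raises ZeroDivisionError inside Dataset.map. Below is a hedged variant of the same average-word-length check with that edge case guarded; it is not part of the commit, and the threshold parameter is hypothetical (the commit hard-codes 3.5).

def remove_formulas_safe(example, threshold=3.5):
    """Average-word-length heuristic from remove_formulas, tolerant of
    empty rows. Sketch only, not the committed implementation."""
    words = example["text"].split()
    if not words:
        # Nothing to measure: leave empty or whitespace-only rows untouched.
        return example
    average = sum(len(word) for word in words) / len(words)
    if average <= threshold:
        # Mostly short tokens: treat as a formula-only paragraph and blank it.
        example["text"] = ""
    return example

It is used the same way as the original, e.g. ds = ds.map(remove_formulas_safe).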