pierre Brault
imit
3ff674d
raw
history blame contribute delete
2.01 kB
from pickle import dump, load
from numpy import argsort, dot, array, any, isin
from matplotlib import pyplot as plt
from gossip_semantic_search.models import ProcessedDataset
from gossip_semantic_search.utils import CustomUnpickler
class Evaluator:
    """Measure top-k retrieval accuracy of embedded queries against embedded contexts.

    Every context is scored for every query by dot-product similarity
    (``embedded_queries @ embedded_context.T``). A query counts as a hit for a
    given k when at least one of its expected context indices (its entry in
    ``y_true``) appears among its k highest-scoring contexts. Accuracy is the
    fraction of hits over all queries.
    """

    def __init__(self,
                 processed_dataset_path: str,
                 k_max: int = None,
                 path_k_plot: str = None):
        """Configure the evaluator; nothing is loaded until ``run()``.

        Args:
            processed_dataset_path: Path of the pickled dataset read by
                ``load_dataset``.
            k_max: If truthy, ``run()`` plots accuracy for k = 1..k_max
                (inclusive).
            path_k_plot: Destination file for the plot; required whenever a
                plot is produced.
        """
        self.processed_dataset_path = processed_dataset_path
        self.k_max = k_max
        self.path_k_plot = path_k_plot
        # Populated by run(), or assigned directly by callers.
        self.processed_dataset: ProcessedDataset = None

    def load_dataset(self):
        """Unpickle and return the dataset stored at ``processed_dataset_path``.

        NOTE(review): CustomUnpickler is presumably a ``pickle.Unpickler``
        subclass. Unpickling can run arbitrary code, so only load trusted files.
        """
        with open(self.processed_dataset_path, 'rb') as file:
            unpickler = CustomUnpickler(file)
            return unpickler.load()

    def _rank_contexts(self):
        """Return, for each query row, all context indices by descending similarity."""
        dataset = self.processed_dataset
        similarity_matrix = dot(dataset.embedded_queries, dataset.embedded_context.T)
        # Reversing an ascending argsort gives a descending order. Its first k
        # columns equal the original ``argsort(...)[:, -k:][:, ::-1]``.
        return argsort(similarity_matrix, axis=1)[:, ::-1]

    def _top_k_accuracy(self, ranking, k: int) -> float:
        """Return the fraction of queries whose expected contexts hit ``ranking[:, :k]``.

        ``k`` is assumed to be >= 1. A ``k`` larger than the number of contexts
        simply keeps every context.
        """
        top_k = ranking[:, :k]
        hits = array([isin(expected, retrieved).any()
                      for expected, retrieved in zip(self.processed_dataset.y_true, top_k)])
        return hits.mean()

    def evaluate_k_context(self,
                           k: int) -> float:
        """Return the top-k accuracy over all queries of the loaded dataset.

        Raises:
            ValueError: if ``k`` < 1. Previously ``k == 0`` silently selected
                every context and reported 1.0.
        """
        if k < 1:
            raise ValueError(f"k must be >= 1, got {k}")
        return self._top_k_accuracy(self._rank_contexts(), k)

    def plot_top_k_graph(self):
        """Plot top-k accuracy for k = 1..k_max (inclusive) and save it to ``path_k_plot``.

        Raises:
            ValueError: if ``path_k_plot`` is not set.
        """
        if self.path_k_plot is None:
            raise ValueError("path_k_plot must be set to save the top-k plot")

        ks = list(range(1, self.k_max + 1))
        # Rank once; each k only takes a different prefix of the same ranking.
        ranking = self._rank_contexts()
        accuracies = [self._top_k_accuracy(ranking, k) for k in ks]

        fig, ax = plt.subplots()
        try:
            # Pass ks explicitly so each point lands on its actual k value.
            ax.plot(ks, accuracies, marker='o', color='b', linestyle='-',
                    markersize=8, linewidth=2, label="top-k accuracy")
            # Chart styling.
            ax.set_title("Taux de contextes corrects parmi les k meilleurs contextes",
                         fontsize=14, fontweight='bold')
            ax.set_xlabel("k", fontsize=12)
            ax.set_ylabel("Accuracy", fontsize=12)
            ax.grid(True, which='both', linestyle='--', linewidth=0.5)
            ax.tick_params(axis='both', labelsize=10)
            ax.legend()
            fig.savefig(self.path_k_plot)
        finally:
            # Close the figure so repeated calls do not draw onto a stale one.
            plt.close(fig)

    def run(self):
        """Load the dataset, then produce the top-k plot when ``k_max`` is truthy."""
        self.processed_dataset = self.load_dataset()
        if self.k_max:
            self.plot_top_k_graph()