Spaces:
Sleeping
Sleeping
from pickle import dump, load
from typing import Optional

from matplotlib import pyplot as plt
from numpy import argsort, dot, array, any, isin

from gossip_semantic_search.models import ProcessedDataset
from gossip_semantic_search.utils import CustomUnpickler
class Evaluator:
    """Evaluate top-k context retrieval accuracy on a pickled processed dataset.

    Each query is scored against every context with a dot product of their
    embeddings (cosine similarity only if the embeddings are normalised --
    not checked here). A query counts as a hit for a given ``k`` when at least
    one of its expected context indices (its ``y_true`` entry, a single index
    or a collection of indices) is among its ``k`` highest-scoring contexts.
    """

    def __init__(self,
                 processed_dataset_path: str,
                 k_max: Optional[int] = None,
                 path_k_plot: Optional[str] = None):
        """
        Args:
            processed_dataset_path: Path to the pickled processed dataset.
            k_max: Largest ``k`` (inclusive) to evaluate when plotting. When
                falsy, :meth:`run` only loads the dataset.
            path_k_plot: File path the top-k accuracy plot is saved to.
        """
        self.processed_dataset_path = processed_dataset_path
        self.k_max = k_max
        self.path_k_plot = path_k_plot
        # Set by run(), or lazily on first evaluation.
        self.processed_dataset: Optional[ProcessedDataset] = None

    def load_dataset(self) -> "ProcessedDataset":
        """Unpickle and return the dataset stored at ``processed_dataset_path``.

        SECURITY: unpickling can execute arbitrary code. Only load files you
        produced yourself or otherwise trust.
        """
        with open(self.processed_dataset_path, 'rb') as file:
            return CustomUnpickler(file).load()

    def _get_dataset(self):
        """Return the loaded dataset, loading it from disk on first use."""
        if self.processed_dataset is None:
            self.processed_dataset = self.load_dataset()
        return self.processed_dataset

    def _rank_contexts(self):
        """Return, per query, all context indices sorted by decreasing score."""
        dataset = self._get_dataset()
        similarity_matrix = dot(dataset.embedded_queries,
                                dataset.embedded_context.T)
        # Reversing the full ascending argsort and taking the first k columns
        # selects exactly the same indices as the last k columns reversed.
        return argsort(similarity_matrix, axis=1)[:, ::-1]

    def _top_k_accuracy(self, ranking, k: int) -> float:
        """Fraction of queries whose expected context is in the top ``k``.

        Args:
            ranking: Output of :meth:`_rank_contexts`.
            k: Number of best-ranked contexts to consider (must be >= 1).

        Returns:
            The hit rate in ``[0, 1]``, or NaN when there are no queries.

        Raises:
            ValueError: If ``k < 1`` or ``y_true`` does not have one entry
                per query.
        """
        if k < 1:
            # A slice of [:, -0:] would select every context and report a
            # spurious 100% accuracy, so reject non-positive k explicitly.
            raise ValueError(f"k must be >= 1, got {k}")
        expected = list(self._get_dataset().y_true)
        if len(expected) != len(ranking):
            # zip() would otherwise silently drop the unmatched queries.
            raise ValueError(
                f"y_true has {len(expected)} entries but there are "
                f"{len(ranking)} queries")
        if not expected:
            return float("nan")
        top_k = ranking[:, :k]
        hits = array([isin(truth, candidates).any()
                      for truth, candidates in zip(expected, top_k)])
        return float(hits.mean())

    def evaluate_k_context(self,
                           k: int) -> float:
        """Return the top-``k`` retrieval accuracy over all queries.

        Raises:
            ValueError: If ``k < 1`` or ``y_true`` does not have one entry
                per query.
        """
        return self._top_k_accuracy(self._rank_contexts(), k)

    def plot_top_k_graph(self):
        """Plot top-k accuracy for k = 1..k_max and save it to ``path_k_plot``.

        Returns:
            The list of accuracies, where index ``i`` holds the result for
            ``k = i + 1``.

        Raises:
            ValueError: If ``k_max`` is not a positive integer or
                ``path_k_plot`` is unset.
        """
        if not self.k_max or self.k_max < 1:
            raise ValueError("k_max must be a positive integer to plot")
        if self.path_k_plot is None:
            raise ValueError("path_k_plot must be set to save the top-k plot")

        # Rank once; every k reuses the same ordering.
        ranking = self._rank_contexts()
        ks = list(range(1, self.k_max + 1))
        accuracies = [self._top_k_accuracy(ranking, k) for k in ks]

        # Use a dedicated figure so repeated calls do not draw over each
        # other, and always release it afterwards.
        fig, ax = plt.subplots()
        try:
            ax.plot(ks, accuracies, marker='o', color='b', linestyle='-',
                    markersize=8, linewidth=2, label="Top-k accuracy")
            # Chart styling.
            ax.set_title("Taux de contextes corrects parmi les k meilleurs contextes",
                         fontsize=14, fontweight='bold')
            ax.set_xlabel("k", fontsize=12)
            ax.set_ylabel("Accuracy", fontsize=12)
            ax.grid(True, which='both', linestyle='--', linewidth=0.5)
            ax.tick_params(axis='both', labelsize=10)
            ax.legend()
            fig.savefig(self.path_k_plot)
        finally:
            plt.close(fig)
        return accuracies

    def run(self):
        """Load the dataset and, if ``k_max`` is set, save the top-k plot."""
        self.processed_dataset = self.load_dataset()
        if self.k_max:
            self.plot_top_k_graph()