Spaces:
Sleeping
Sleeping
File size: 2,010 Bytes
3ff674d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 |
from pickle import dump, load
from numpy import argsort, dot, array, any, isin
from matplotlib import pyplot as plt
from gossip_semantic_search.models import ProcessedDataset
from gossip_semantic_search.utils import CustomUnpickler
class Evaluator:
    """Evaluate query-to-context retrieval with top-k accuracy.

    The processed dataset is unpickled from ``processed_dataset_path`` and must
    expose ``embedded_queries``, ``embedded_context`` and ``y_true`` (the
    attributes read by :meth:`evaluate_k_context`).
    """

    def __init__(self,
                 processed_dataset_path: str,
                 k_max: int = None,
                 path_k_plot: str = None):
        """
        Args:
            processed_dataset_path: Path of the pickled processed dataset.
            k_max: Largest k (inclusive) shown on the top-k plot; :meth:`run`
                skips plotting when this is falsy.
            path_k_plot: File the top-k plot is saved to.
        """
        self.processed_dataset_path = processed_dataset_path
        self.k_max = k_max
        self.path_k_plot = path_k_plot
        # Filled in by run() from load_dataset().
        self.processed_dataset: ProcessedDataset = None

    def load_dataset(self):
        """Unpickle and return the object stored at ``processed_dataset_path``.

        Security: unpickling can execute arbitrary code. Only load files from
        trusted sources. NOTE(review): confirm whether CustomUnpickler
        restricts loadable classes; nothing visible here does.
        """
        with open(self.processed_dataset_path, 'rb') as file:
            return CustomUnpickler(file).load()

    def _rank_contexts(self):
        """Return, per query row, context indices sorted by decreasing dot-product score."""
        dataset = self.processed_dataset
        similarity_matrix = dot(dataset.embedded_queries, dataset.embedded_context.T)
        # Reverse the ascending argsort so column 0 holds the best context.
        return argsort(similarity_matrix, axis=1)[:, ::-1]

    def evaluate_k_context(self,
                           k: int,
                           ranking=None) -> float:
        """Return the fraction of queries with a correct context among their k best.

        Args:
            k: Number of highest-scoring contexts considered per query; must be >= 1.
            ranking: Optional precomputed result of ``_rank_contexts()``. Passing
                it avoids recomputing the similarity matrix for every k.

        Returns:
            Mean over queries of 1 when any entry of the query's ``y_true``
            (compared against context indices) is among its k best contexts,
            else 0.

        Raises:
            ValueError: If ``k < 1``. A ``[-0:]`` slice would otherwise silently
                select every context.
        """
        if k < 1:
            raise ValueError(f"k must be >= 1, got {k}")
        if ranking is None:
            ranking = self._rank_contexts()
        k_best_context = ranking[:, :k]
        # y_true entries may be a single index or several; isin handles both.
        result = array([1 if any(isin(truth, top_k)) else 0
                        for truth, top_k in zip(self.processed_dataset.y_true, k_best_context)])
        return result.mean()

    def plot_top_k_graph(self):
        """Plot top-k accuracy for k = 1..k_max (inclusive) and save it to ``path_k_plot``.

        Raises:
            ValueError: If ``path_k_plot`` is not set.
        """
        if self.path_k_plot is None:
            raise ValueError("path_k_plot must be set to save the top-k plot")
        # Rank once; every k only needs a different prefix of the same ranking.
        ranking = self._rank_contexts()
        ks = list(range(1, self.k_max + 1))
        accuracies = [self.evaluate_k_context(k, ranking) for k in ks]

        # Use a dedicated figure and always close it so repeated calls neither
        # draw over each other nor accumulate open figures.
        fig = plt.figure()
        try:
            plt.plot(ks, accuracies, marker='o', color='b', linestyle='-', markersize=8, linewidth=2)
            # Chart styling
            plt.title("Taux de contextes corrects parmi les k meilleurs contextes", fontsize=14, fontweight='bold')
            plt.xlabel("k", fontsize=12)
            plt.ylabel("Accuracy", fontsize=12)
            plt.grid(True, which='both', linestyle='--', linewidth=0.5)
            plt.xticks(fontsize=10)
            plt.yticks(fontsize=10)
            plt.savefig(self.path_k_plot)
        finally:
            plt.close(fig)

    def run(self):
        """Load the dataset and, when ``k_max`` is truthy, save the top-k plot."""
        self.processed_dataset = self.load_dataset()
        if self.k_max:
            self.plot_top_k_graph()
|