File size: 2,010 Bytes
3ff674d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
from pickle import dump, load
from typing import Optional

from matplotlib import pyplot as plt
from numpy import argsort, dot, array, any, isin

from gossip_semantic_search.models import ProcessedDataset
from gossip_semantic_search.utils import CustomUnpickler


class Evaluator:
    """Measure top-k retrieval accuracy of embedded queries against embedded contexts.

    A query counts as a hit for a given k when at least one of its ground-truth
    context indices (``y_true``) is among the k contexts with the highest
    dot-product similarity to the query embedding.
    """

    def __init__(self,
                 processed_dataset_path: str,
                 k_max: Optional[int] = None,
                 path_k_plot: Optional[str] = None):
        """
        Args:
            processed_dataset_path: Path of the pickled ``ProcessedDataset`` to load.
            k_max: Largest k (inclusive) plotted by ``plot_top_k_graph``. When
                falsy, ``run`` only loads the dataset.
            path_k_plot: File path the top-k plot is saved to; required when
                ``k_max`` is set.
        """
        self.processed_dataset_path = processed_dataset_path
        self.k_max = k_max
        self.path_k_plot = path_k_plot

        # Filled in by ``run`` (via ``load_dataset``).
        self.processed_dataset: Optional[ProcessedDataset] = None

    def load_dataset(self):
        """Unpickle and return the object stored at ``processed_dataset_path``.

        SECURITY: unpickling can execute arbitrary code -- only load files that
        come from a trusted source.
        """
        with open(self.processed_dataset_path, 'rb') as file:
            return CustomUnpickler(file).load()

    def _rank_contexts(self):
        """Return, per query, all context indices sorted by decreasing similarity.

        Raises:
            RuntimeError: if the dataset has not been loaded yet.
        """
        if self.processed_dataset is None:
            raise RuntimeError("processed dataset not loaded; call run() first")
        similarity_matrix = dot(self.processed_dataset.embedded_queries,
                                self.processed_dataset.embedded_context.T)
        # argsort is ascending; reverse the columns so the best match comes first.
        return argsort(similarity_matrix, axis=1)[:, ::-1]

    def evaluate_k_context(self,
                           k: int,
                           ranked_contexts=None) -> float:
        """Return the fraction of queries whose true context is in their top-k.

        Args:
            k: Number of best-scoring contexts considered per query (>= 0).
                k=0 yields 0.0; k larger than the number of contexts considers
                every context.
            ranked_contexts: Optional precomputed output of ``_rank_contexts``;
                pass it when evaluating several k to avoid re-ranking each time.

        Raises:
            ValueError: if ``k`` is negative.
        """
        if k < 0:
            raise ValueError(f"k must be non-negative, got {k}")
        if ranked_contexts is None:
            ranked_contexts = self._rank_contexts()
        # Take the first k of the descending order; unlike ``[:, -k:]`` this
        # keeps k=0 meaning "no candidates" rather than "all candidates".
        k_best_context = ranked_contexts[:, :k]
        hits = array([1 if any(isin(true_ids, candidates)) else 0
                      for true_ids, candidates in zip(self.processed_dataset.y_true,
                                                      k_best_context)])
        return hits.mean()

    def plot_top_k_graph(self):
        """Plot top-k accuracy for k = 1..k_max (inclusive) and save it to ``path_k_plot``.

        Raises:
            ValueError: if ``path_k_plot`` is not set.
        """
        if self.path_k_plot is None:
            raise ValueError("path_k_plot must be set to save the top-k plot")

        # The ranking does not depend on k: compute it once for every k.
        ranked_contexts = self._rank_contexts()
        ks = list(range(1, self.k_max + 1))
        accuracies = [self.evaluate_k_context(k, ranked_contexts) for k in ks]

        # Use a dedicated figure and close it afterwards so repeated calls do
        # not draw on top of each other or accumulate open figures.
        fig = plt.figure()
        try:
            plt.plot(ks, accuracies, marker='o', color='b', linestyle='-',
                     markersize=8, linewidth=2, label="Top-k accuracy")

            # Chart styling
            plt.title("Taux de contextes corrects parmi les k meilleurs contextes",
                      fontsize=14, fontweight='bold')
            plt.xlabel("k", fontsize=12)
            plt.ylabel("Accuracy", fontsize=12)
            plt.grid(True, which='both', linestyle='--', linewidth=0.5)
            plt.xticks(fontsize=10)
            plt.yticks(fontsize=10)
            plt.legend()
            plt.savefig(self.path_k_plot)
        finally:
            plt.close(fig)

    def run(self):
        """Load the dataset, then save the top-k plot when ``k_max`` is set."""
        self.processed_dataset = self.load_dataset()
        if self.k_max:
            self.plot_top_k_graph()