pierre Brault commited on
Commit
3ff674d
·
1 Parent(s): f0648ac
.gitignore ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ venv/
2
+ .gradio/
3
+ .idea/
4
+ set_env.sh
5
+ datasets/
6
+ explo.ipynb
7
+ *.pyc
8
+ .DS_Store
README.md CHANGED
@@ -1,12 +1,159 @@
 
 
 
 
 
 
 
 
 
 
1
  ---
2
- title: Gossip Semantic Search Demo
3
- emoji: 🚀
4
- colorFrom: purple
5
- colorTo: gray
6
- sdk: gradio
7
- sdk_version: 5.1.0
8
- app_file: app.py
9
- pinned: false
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  ---
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Gossip Answering
2
+
3
+ ## Quick Start
4
+
5
+ Cette application utilise **Gradio** pour créer une interface permettant de poser des questions sur des articles provenant de [VSD](https://vsd.fr/) et [Public](https://www.public.fr/) en utilisant des techniques de **Retrieval-Augmented Generation (RAG)**.
6
+
7
+ Ce projet utilise deux services API, et il vous sera nécessaire de générer des clés pour accéder à ces API. :
8
+ - **[Cohere](https://dashboard.cohere.com/api-keys)** pour générer des embeddings.
9
+ - **[Groq](https://console.groq.com/keys)** pour générer du contenu.
10
+
11
  ---
12
+ ### colab ###
13
+
14
+ Vous pouvez accéder au [notebook disponible sur Google Colab ici.](https://colab.research.google.com/drive/1pG4-s2Tg-9o7U0bRZrV1WELTy2k42qk5?usp=drive_link)
15
+
16
+
17
+ Dans la première cellule, vous devrez entrer votre GitHub token pour cloner le repo ainsi que les clés des deux API utilisées par ce projet.
18
+
19
+ Une fois cela fait, exécutez toutes les cellules du notebook, puis rendez-vous sur le lien généré `.gradio.live` à la fin de la dernière cellule pour accéder à l'application.
20
+
21
+ ---
22
+ ### Local ###
23
+ #### Prérequis
24
+
25
+ - **Python 3.12.3** (La version utilisée pour ce projet)
26
+
27
+ #### Installation
28
+
29
+ 1. **Clonez le dépôt :**
30
+
31
+ Clonez le projet sur votre machine locale en utilisant la commande suivante :
32
+
33
+ ```
34
+ git clone https://github.com/PierreBdesG/gossip-semantic-search.git
35
+ ```
36
+
37
+ 2. **Récupérez les clés API :**
38
+
39
+ Définissez vos API keys comme variables d'environnement :
40
+
41
+ ```
42
+ export COHERE_API_KEY=your_cohere_api_key
43
+ export GROQ_API_KEY=your_groq_api_key
44
+ ```
45
+
46
+ 3. **Installez les dépendances :**
47
+
48
+ Installez les bibliothèques nécessaires à l'aide du fichier `requirements.txt` :
49
+
50
+ ```
51
+ pip install -r requirements.txt
52
+ ```
53
+
54
+ 4. **Téléchargez le dataset :**
55
+
56
+ Téléchargez le dataset (mis à jour le 19 octobre 2024) :
57
+ ```
58
+ !curl -L -o dataset.pkl "https://drive.google.com/uc?export=download&id=1g4-pNY3CTRIW5_hinWBREMkTrsZ05GNd"
59
+ ```
60
+
61
+ 5. **Lancez l'application :**
62
+
63
+ Exécutez l'application avec la commande suivante, en fournissant le chemin vers le dataset téléchargé :
64
+
65
+ ```
66
+ python gossip_semantic_search/front --dataset_path dataset.pkl
67
+ ```
68
+
69
  ---
70
 
71
+ ## Création du Dataset
72
+
73
+ Le dataset utilisé dans ce projet est une **`List[Article]`**, où chaque élément est un objet `Article` contenant les informations suivantes :
74
+
75
+ ```
76
+ class Article:
77
+ author: str
78
+ title: str
79
+ link: str
80
+ description: str
81
+ published_date: datetime
82
+ content: str
83
+ embeded_content: NDArray
84
+ questions: list[str]
85
+ ```
86
+
87
+ - Les attributs **`author`**, **`title`**, **`link`**, **`description`**, **`published_date`** et **`content`** sont extraits des flux RSS des sites [VSD](https://vsd.fr/feed) et [Public](https://www.public.fr/feed).
88
+ - Le champ **`embeded_content`** correspond aux embeddings générés pour chaque article. Une description détaillée de l'embedding utilisé est disponible [ici](https://cohere.com/blog/introducing-embed-v3).
89
+ - Le champ **`questions`** contient une liste de `n` questions associées à l'article. Ces questions ont été générées à l'aide de [`Llama-3-70B`](https://huggingface.co/meta-llama/Meta-Llama-3-70B), via l'API fournie par [Groq](https://groq.com/).
90
+
91
+ ---
92
+
93
+ ## Évaluation ##
94
+
95
+ Pour ce dataset, nous avons choisi de générer 3 questions par article. Avec un total de 60 articles, cela constitue un dataset d'environ 180 questions.
96
+
97
+ Une fois ce dataset créé, nous pouvons évaluer notre modèle de retrieval de la manière suivante :
98
+ - Effectuer un produit scalaire (similaire a un cosine similarity vue que les normes des embedding=1) entre tous les embeddings des questions et tous les embeddings des contextes.
99
+ - Pour chaque question, récupérer les `k` meilleurs contextes correspondant.
100
+
101
+ ce qui nous donne les resultats suivant:
102
+ ![k_plot_graph.png](gossip_semantic_search/evaluator/k_plot_graph.png)
103
+
104
+ /!\ regénerer le graph en décalant les absisse de 1
105
+
106
+ #### Résultats du Graph
107
+
108
+ Sur ce graphique, on observe les résultats suivants pour la récupération des contextes :
109
+
110
+ - Si l'on se base uniquement sur le **dot product** le plus élevé pour chaque question, dans **91.5% des cas**, le contexte d'origine est retrouvé.
111
+ - En élargissant la sélection aux **2 meilleurs contextes**, dans environ **99% des cas**, le contexte d'origine se trouve parmi les deux contextes sélectionnés.
112
+ - En prenant les **3 meilleurs contextes**, on atteint une précision de **99.5%**, où le contexte d'origine se retrouve dans les trois contextes sélectionnés.
113
+
114
+
115
+ #### Limitation ####
116
+
117
+ Ce dataset est un dataset **synthétique**. Les questions ont été générées par le modèle **[Llama-3-70B](https://huggingface.co/meta-llama/Meta-Llama-3-70B)**, en prenant en compte un certain article spécifique pour chaque question. Les questions sont donc conçues pour correspondre directement au contenu des articles, créant ainsi un ensemble de données biaisé par apport au réel avec
118
+ des questions bien alignées avec les informations présentes dans le texte.
119
+ Lors de l'inférence, les questions posées par les utilisateurs ne seront pas toujours aussi structurées ou spécifiques que celles générées pour ce dataset. Les questions réelles, posées par des utilisateurs humains, seront souvent plus organiques, plus variées, et moins directement liées à un contexte particulier
120
+ En conséquence, les performances du modèle seront donc moins élevées lorsqu'il sera confronté à des questions plus naturelles
121
+
122
+ ---
123
+
124
+ ## Inférence et génération de réponses ##
125
+
126
+ 1.Une question est entrée par un utilisateur.
127
+
128
+ 2. La question est convertie en embedding.
129
+
130
+ 3. On effectue un dot product entre l'embedding de la question et tous les articles embeddés, qui ont été chargés en RAM lors du lancement de l'application.
131
+
132
+ 4. On récupère les 3 articles les plus pertinents.
133
+
134
+ 5. On prend le premier article et on demande à **[Llama-3-70B](https://huggingface.co/meta-llama/Meta-Llama-3-70B)** si la réponse à la question se trouve dans le contexte de cet article.
135
+ - Si oui, on demande à **Llama-3-70B** de générer la réponse, puis on renvoie la réponse, le lien de l'article et le contexte correspondant.
136
+ - Si non, on passe au deuxième article et répète la même procédure.
137
+ - Si la réponse n'est trouvée dans aucun des 3 articles, on renvoie le message **"incapable de générer une réponse"** ainsi que le premier contexte.
138
+
139
+ ------
140
+
141
+ ## Exploration
142
+ Si j'avais eu plus de temps, j'aurais :
143
+
144
+ - Généré un jeu de données plus complexe.
145
+ - Implémenté des KNN plutôt que d'utiliser le produit scalaire pour le retriever.
146
+ - Mieux exploré et analysé les questions mal classifiées.
147
+ - Mis en place un reformulateur de questions (multi-query, [hyde](https://arxiv.org/abs/2212.10496)).
148
+ - Exploré davantage Cohere et d'autres embedders.
149
+ - Mis en place ColBERT sur les K meilleurs contextes.
150
+
151
+ ---
152
+
153
+ reste a faire: regler le pb de pickle
154
+
155
+ replot le graph
156
+
157
+ test
158
+
159
+ linter
gossip_semantic_search/__init__.py ADDED
File without changes
gossip_semantic_search/constant.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ HEADERS = {
2
+ 'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) '
3
+ 'AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
4
+ }
5
+ CHANNEL_KEY = "channel"
6
+ ITEM_KEY = "item"
7
+ AUTHOR_KEY = "{http://purl.org/dc/elements/1.1/}creator"
8
+ TITLE_KEY = "title"
9
+ LINK_KEY = "link"
10
+ DESCRIPTION_KEY = "description"
11
+ PUBLICATION_DATE_KEY = "pubDate"
12
+ CONTENT_KEY = "{http://purl.org/rss/1.0/modules/content/}encoded"
13
+ LLAMA_70B_MODEL = "llama3-groq-70b-8192-tool-use-preview"
14
+ EMBEDING_MODEL = "embed-multilingual-v3.0"
15
+ DATE_FORMAT = "%a, %d %b %Y %H:%M:%S %z"
gossip_semantic_search/dataset_creator/__init__.py ADDED
File without changes
gossip_semantic_search/dataset_creator/__main__.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import os
3
+
4
+ project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..'))
5
+ sys.path.append(project_root)
6
+
7
+ import argparse
8
+
9
+ import yaml
10
+
11
+ from dataset_creator import DatasetCreator
12
+
13
+ def main():
14
+ parser = argparse.ArgumentParser()
15
+ parser.add_argument(
16
+ '--config', type=str, help='Chemin vers le fichier YAML de configuration.', required=True
17
+ )
18
+
19
+ args = parser.parse_args()
20
+ with open(args.config, 'r') as file:
21
+ config = yaml.safe_load(file)
22
+
23
+ dataset_creator = DatasetCreator(urls=config.get('urls', []),
24
+ save_path=config.get('save_path'),
25
+ number_questions=config.get('number_questions'),
26
+ embed_articles=config.get('embed_articles'))
27
+ dataset_creator.run()
28
+
29
+ if __name__ == '__main__':
30
+ main()
31
+
32
+
gossip_semantic_search/dataset_creator/config.yml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ urls:
2
+ - "https://www.public.fr/feed"
3
+ - "https://www.vsd.fr/feed"
4
+ save_path: "datasets/dataset.pkl"
5
+ number_questions: 3
6
+ embed_articles: True
gossip_semantic_search/dataset_creator/dataset_creator.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pickle
2
+ from typing import List
3
+ from xml.etree.ElementTree import fromstring
4
+
5
+ import requests
6
+ from tqdm import tqdm
7
+ from groq import Groq
8
+ from cohere import Client
9
+
10
+ from gossip_semantic_search.constant import HEADERS, CHANNEL_KEY, ITEM_KEY
11
+ from gossip_semantic_search.utils import (xml_to_dict, article_raw_to_article,
12
+ generates_questions, embed_content)
13
+ from gossip_semantic_search.models import Article
14
+
15
+ class DatasetCreator:
16
+ def __init__(self,
17
+ urls: List[str],
18
+ save_path: str = None,
19
+ number_questions:int = 0,
20
+ embed_articles: bool = False):
21
+
22
+ self.urls = urls
23
+ self.save_path = save_path
24
+ self.number_questions = number_questions
25
+ self.embed_articles = embed_articles
26
+
27
+ self.articles: List[Article] = []
28
+
29
+ def extract_articles(self):
30
+ for url in self.urls:
31
+ response = requests.get(url, headers=HEADERS)
32
+ xml_string = response.text
33
+ root = fromstring(xml_string)
34
+ articles_raw = xml_to_dict(root)[CHANNEL_KEY][ITEM_KEY]
35
+ self.articles.extend([article_raw_to_article(article_raw)
36
+ for article_raw in articles_raw])
37
+
38
+ def save_articles(self):
39
+ with open(self.save_path, 'wb') as f:
40
+ pickle.dump(self.articles, f)
41
+
42
+ def generate_questions(self):
43
+ client = Groq()
44
+ for article in tqdm(self.articles, desc="Generating questions"):
45
+ article.questions = generates_questions(article.content, self.number_questions, client)
46
+
47
+ def embed_article(self):
48
+ client = Client()
49
+ for article in tqdm(self.articles, desc="Embedding content"):
50
+ article.embeded_content = embed_content([article.content], client)[0, :]
51
+
52
+ def run(self):
53
+ self.extract_articles()
54
+
55
+ if self.number_questions:
56
+ self.generate_questions()
57
+
58
+ if self.embed_articles:
59
+ self.embed_article()
60
+
61
+ if self.articles:
62
+ self.save_articles()
gossip_semantic_search/dataset_processor/__init__.py ADDED
File without changes
gossip_semantic_search/dataset_processor/__main__.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import os
3
+
4
+ project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..'))
5
+ sys.path.append(project_root)
6
+ print(project_root)
7
+
8
+ import argparse
9
+
10
+ import yaml
11
+
12
+ from dataset_processor import DatasetProcessor
13
+
14
+ def main():
15
+ parser = argparse.ArgumentParser()
16
+ parser.add_argument(
17
+ '--config', type=str, help='Chemin vers le fichier YAML de configuration.', required=True
18
+ )
19
+
20
+ args = parser.parse_args()
21
+ with open(args.config, 'r') as file:
22
+ config = yaml.safe_load(file)
23
+
24
+ dataset_processor = DatasetProcessor(
25
+ dataset_path=config.get('dataset_path'),
26
+ saved_processed_dataset_path=config.get('saved_processed_dataset_path'))
27
+ dataset_processor.run()
28
+
29
+ if __name__ == '__main__':
30
+ main()
gossip_semantic_search/dataset_processor/config.yml ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ dataset_path: "datasets/dataset.pkl"
2
+ saved_processed_dataset_path: "datasets/processed_dataset.pkl"
gossip_semantic_search/dataset_processor/dataset_processor.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pickle import dump
2
+ from typing import List
3
+
4
+ from cohere import Client
5
+ from numpy import array
6
+
7
+ from gossip_semantic_search.models import Article, ProcessedDataset
8
+ from gossip_semantic_search.utils import embed_content, CustomUnpickler
9
+
10
+
11
+ class DatasetProcessor:
12
+ def __init__(self,
13
+ dataset_path: str,
14
+ saved_processed_dataset_path: str):
15
+ self.dataset_path = dataset_path
16
+ self.saved_processed_dataset_path = saved_processed_dataset_path
17
+
18
+ self.processed_dataset: ProcessedDataset = None
19
+
20
+ @staticmethod
21
+ def load_dataset(dataset_path: str) -> List[Article]:
22
+ with open(dataset_path, 'rb') as file:
23
+ unpickler = CustomUnpickler(file)
24
+ data = unpickler.load()
25
+ return data
26
+
27
+
28
+ def process_dataset(self,
29
+ data: List[Article]):
30
+ client = Client()
31
+
32
+ y_true = []
33
+ questions = []
34
+ for i, sample in enumerate(data):
35
+ for question in sample.questions:
36
+ y_true.append(i)
37
+ questions.append(question)
38
+
39
+ self.processed_dataset = ProcessedDataset(
40
+ y_true = array(y_true),
41
+ embedded_queries=embed_content(questions, client),
42
+ embedded_context=array([sample.embeded_content for sample in data]))
43
+
44
+ def save_articles(self):
45
+ with open(self.saved_processed_dataset_path, 'wb') as f:
46
+ dump(self.processed_dataset, f)
47
+
48
+ def run(self):
49
+ data = self.load_dataset(self.dataset_path)
50
+ self.process_dataset(data)
51
+ self.save_articles()
gossip_semantic_search/evaluator/__main__.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import os
3
+
4
+ project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..'))
5
+ sys.path.append(project_root)
6
+
7
+ import argparse
8
+
9
+ import yaml
10
+
11
+ from evaluator import Evaluator
12
+
13
+ def main():
14
+ parser = argparse.ArgumentParser()
15
+ parser.add_argument('--config', type=str, help='Chemin vers le fichier YAML de configuration.', required=True)
16
+
17
+ args = parser.parse_args()
18
+ with open(args.config, 'r') as file:
19
+ config = yaml.safe_load(file)
20
+
21
+ evaluator = Evaluator(
22
+ processed_dataset_path=config.get('processed_dataset_path'),
23
+ k_max=config.get('k_max'),
24
+ path_k_plot=config.get('path_k_plot'),
25
+ )
26
+ evaluator.run()
27
+
28
+ if __name__ == '__main__':
29
+ main()
30
+
gossip_semantic_search/evaluator/config.yml ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ processed_dataset_path: "datasets/processed_dataset.pkl"
2
+ k_max: 6
3
+ path_k_plot: "gossip_semantic_search/evaluator/k_plot_graph.png"
gossip_semantic_search/evaluator/evaluator.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pickle import dump, load
2
+
3
+ from numpy import argsort, dot, array, any, isin
4
+ from matplotlib import pyplot as plt
5
+ from gossip_semantic_search.models import ProcessedDataset
6
+ from gossip_semantic_search.utils import CustomUnpickler
7
+
8
+
9
+ class Evaluator:
10
+ def __init__(self,
11
+ processed_dataset_path: str,
12
+ k_max:int = None,
13
+ path_k_plot:str = None):
14
+
15
+ self.processed_dataset_path = processed_dataset_path
16
+ self.k_max = k_max
17
+ self.path_k_plot = path_k_plot
18
+
19
+ self.processed_dataset: ProcessedDataset = None
20
+
21
+ def load_dataset(self):
22
+ with open(self.processed_dataset_path, 'rb') as file:
23
+ unpickler = CustomUnpickler(file)
24
+ return unpickler.load()
25
+
26
+ def evaluate_k_context(self,
27
+ k:int) -> float:
28
+ similarity_matrix = dot(self.processed_dataset.embedded_queries, self.processed_dataset.embedded_context.T)
29
+ k_best_context = argsort(similarity_matrix, axis=1)[:, -k:][:, ::-1]
30
+ result = array([1 if any(isin(row, val)) else 0
31
+ for row, val in zip(self.processed_dataset.y_true, k_best_context)])
32
+ return result.mean()
33
+
34
+ def plot_top_k_graph(self):
35
+ accuracies = []
36
+ for k in range(1, self.k_max):
37
+ accuracies.append(self.evaluate_k_context(k))
38
+ plt.plot(accuracies, marker='o', color='b', linestyle='-', markersize=8, linewidth=2)
39
+
40
+ # Personnalisation du graphique
41
+ plt.title("Taux de contextes corrects parmi les k meilleurs contextes", fontsize=14, fontweight='bold')
42
+ plt.xlabel("k", fontsize=12)
43
+ plt.ylabel("Accuracy", fontsize=12)
44
+ plt.grid(True, which='both', linestyle='--', linewidth=0.5)
45
+ plt.xticks(fontsize=10)
46
+ plt.yticks(fontsize=10)
47
+ plt.legend()
48
+ plt.savefig(self.path_k_plot)
49
+
50
+ def run(self):
51
+ self.processed_dataset = self.load_dataset()
52
+ if self.k_max:
53
+ self.plot_top_k_graph()
gossip_semantic_search/evaluator/k_plot_graph.png ADDED
gossip_semantic_search/front/__init__.py ADDED
File without changes
gossip_semantic_search/front/app.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import os
3
+
4
+ project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..'))
5
+ sys.path.append(project_root)
6
+
7
+ import gradio as gr
8
+ from cohere import Client
9
+ from groq import Groq
10
+ from predictor import Predictor
11
+
12
+ COHERE_API_KEY="8hr6huyTDgAnbbU3WU4mnXTovfa2dwIeV0kc5Uf5"
13
+ GROQ_API_KEY="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"
14
+
15
+ def main():
16
+ embeding_client = Client(api_key=COHERE_API_KEY)
17
+ gen_client = Groq(api_key=GROQ_API_KEY)
18
+
19
+ predictor = Predictor(dataset_path="datasets/dataset.pkl",
20
+ embeding_client=embeding_client,
21
+ QA_boosted=True,
22
+ generative_client=gen_client,)
23
+ predictor.setup()
24
+
25
+ def make_prediction(query):
26
+ answer = predictor.make_prediction(query)
27
+ return answer.answer, answer.link, answer.content
28
+
29
+ iface = gr.Interface(
30
+ fn=make_prediction,
31
+ inputs=gr.Textbox(label="query", value="Alain Delon connait-il Anne-Elisabeth Lemoine?"),
32
+ outputs=[gr.Textbox(label="reponse"),
33
+ gr.Textbox(label="link"),
34
+ gr.Textbox(label="context")],
35
+ title="Gossip answering",
36
+ description="Poser une question Gossip, peut-etre y'aura t'il la reponse"
37
+ )
38
+
39
+ iface.launch()
40
+
41
+ if __name__ == '__main__':
42
+ main()
gossip_semantic_search/front/predictor.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+
3
+ from cohere import Client
4
+ from groq import Groq
5
+ from numpy import array, argsort, dot
6
+ from numpy.typing import NDArray
7
+
8
+ from gossip_semantic_search.models import Article, Answer
9
+ from gossip_semantic_search.utils import (CustomUnpickler, embed_content,
10
+ choose_context_and_answer_questions)
11
+
12
+
13
+ class Predictor:
14
+ def __init__(self,
15
+ dataset_path:str,
16
+ embeding_client: Client,
17
+ QA_boosted: bool = False,
18
+ generative_client: Groq = None):
19
+ self.dataset_path = dataset_path
20
+ self.embeding_client = embeding_client
21
+ self.generative_client = generative_client
22
+ self.QA_boosted = QA_boosted
23
+
24
+ self.dataset: List[Article] = []
25
+ self.embeded_contents: NDArray
26
+
27
+ @staticmethod
28
+ def load_dataset(dataset_path: str) -> List[Article]:
29
+ with open(dataset_path, 'rb') as file:
30
+ unpickler = CustomUnpickler(file)
31
+ data = unpickler.load()
32
+ return data
33
+
34
+ def setup(self):
35
+ self.dataset = self.load_dataset(self.dataset_path)
36
+ self.embeded_contents = array([sample.embeded_content for sample in self.dataset])
37
+
38
+ def make_prediction(self,
39
+ query:str) -> Answer :
40
+ embeded_query = embed_content([query], self.embeding_client)
41
+ similarity_vector = dot(embeded_query, self.embeded_contents.T)
42
+ k_best_article_index = argsort(similarity_vector, axis=1)[:, -3:][:, ::-1]
43
+
44
+ if self.QA_boosted:
45
+ k_best_article = [self.dataset[k] for k in k_best_article_index[0]]
46
+ return choose_context_and_answer_questions(
47
+ k_best_article,
48
+ query,
49
+ self.generative_client
50
+ )
51
+
52
+ best_article = self.dataset[k_best_article_index[0][0]]
53
+ return Answer(
54
+ answer = "pas de reponse, activer QA_boosted pour générer des reponses",
55
+ link = f"{best_article.link}",
56
+ content = f"{best_article.content}"
57
+ )
gossip_semantic_search/models.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+
3
+ from datetime import datetime
4
+ from numpy.typing import NDArray
5
+
6
+ @dataclass
7
+ class Article:
8
+ author: str
9
+ title: str
10
+ link: str
11
+ description: str
12
+ published_date: datetime
13
+ content: str
14
+ embeded_content: NDArray = None
15
+ questions: list[str] = None
16
+
17
+ @dataclass
18
+ class Answer:
19
+ link: str
20
+ content: str
21
+ answer: str
22
+
23
+ @dataclass
24
+ class ProcessedDataset:
25
+ y_true: NDArray[int]
26
+ embedded_queries: NDArray[float]
27
+ embedded_context: NDArray[float]
gossip_semantic_search/prompts.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ def generate_question_prompt(context, nb_questions):
2
+ return(f"À partir du contexte suivant : '{context}',"
3
+ f" \n générez {nb_questions} questions pertinentes en français qui pourraient "
4
+ f"être posées en lien avec ce contexte.\n"
5
+ f"Les questions devront avoir la forme :\n" +
6
+ "\n".join([f"{i + 1}. ceci est-il une question {i + 1}?" for i in range(nb_questions)]))
7
+
8
+
9
+ def generate_context_retriver_prompt(query, context):
10
+ return(f"je vais te donner 1 texte et une question, tu devras me dire si la reponse est dans le texte "
11
+ f"et si elle y est, me générer une reponse. \n"
12
+ f"voici la question: {query} \n "
13
+ f"voici le texte: {context} \n "
14
+ f"si la reponse est dans le texte ta reponse doit avor la forme suivante: \n "
15
+ f" (answer_in_text=True, answer='La réponse généré')"
16
+ f"si la reponse n'est pas dans le texte ta reponse doit avor la forme suivante:"
17
+ f" (answer_in_text=False, answer=None)")
gossip_semantic_search/utils.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datetime import datetime
2
+ import warnings
3
+ from typing import List
4
+ from pickle import Unpickler
5
+ import re
6
+
7
+ from bs4 import BeautifulSoup
8
+ from groq import Groq
9
+ from cohere import Client
10
+ from numpy.typing import NDArray
11
+ from numpy import array
12
+
13
+ from gossip_semantic_search.models import Article, Answer
14
+ from gossip_semantic_search.constant import (AUTHOR_KEY, TITLE_KEY, LINK_KEY, DESCRIPTION_KEY,
15
+ PUBLICATION_DATE_KEY, CONTENT_KEY, LLAMA_70B_MODEL,
16
+ DATE_FORMAT, EMBEDING_MODEL)
17
+ from gossip_semantic_search.prompts import (generate_question_prompt,
18
+ generate_context_retriver_prompt)
19
+
20
+
21
+ def xml_to_dict(element):
22
+ result = {}
23
+
24
+ for child in element:
25
+ child_dict = xml_to_dict(child)
26
+ if child.tag in result:
27
+ if isinstance(result[child.tag], list):
28
+ result[child.tag].append(child_dict)
29
+ else:
30
+ result[child.tag] = [result[child.tag], child_dict]
31
+ else:
32
+ result[child.tag] = child_dict
33
+
34
+ if element.text and element.text.strip():
35
+ result = element.text.strip()
36
+
37
+ return result
38
+
39
+
40
+ def sanitize_html_content(html_content):
41
+ soup = BeautifulSoup(html_content, 'html.parser')
42
+
43
+ for a in soup.find_all('a'):
44
+ a.unwrap()
45
+
46
+ for tag in soup.find_all(['em', 'strong']):
47
+ tag.unwrap()
48
+
49
+ for blockquote in soup.find_all('blockquote'):
50
+ blockquote.extract()
51
+
52
+ cleaned_text = re.sub(r'\s+', ' ', soup.get_text()).strip()
53
+ return cleaned_text
54
+
55
+ def article_raw_to_article(raw_article) -> Article:
56
+ return Article(
57
+ author = raw_article[AUTHOR_KEY],
58
+ title = raw_article[TITLE_KEY],
59
+ link = raw_article[LINK_KEY],
60
+ description = raw_article[DESCRIPTION_KEY],
61
+ published_date = datetime.strptime(
62
+ raw_article[PUBLICATION_DATE_KEY],
63
+ DATE_FORMAT
64
+ ),
65
+ content = sanitize_html_content(raw_article[CONTENT_KEY])
66
+ )
67
+
68
+ def generates_questions(context: str,
69
+ nb_questions: int,
70
+ client: Groq) -> List[str]:
71
+ completion = client.chat.completions.create(
72
+ model=LLAMA_70B_MODEL,
73
+ messages=[
74
+ {
75
+ "role": "user",
76
+ "content": generate_question_prompt(context, nb_questions)
77
+ },
78
+ ],
79
+ temperature=1,
80
+ max_tokens=1024,
81
+ top_p=1,
82
+ stream=True,
83
+ stop=None,
84
+ )
85
+ questions_str = "".join(chunk.choices[0].delta.content or "" for chunk in completion)
86
+
87
+ try:
88
+ questions = re.findall(r'([^?]*\?)', questions_str)
89
+ questions = [question.strip()[3:] for question in questions]
90
+
91
+ except IndexError:
92
+ warnings.warn(f"no question found. \n"
93
+ f"string return: {questions_str}")
94
+ return []
95
+
96
+ if len(questions) != nb_questions:
97
+ warnings.warn(f"Expected {nb_questions} questions, but found "
98
+ f"{len(questions)}. {', '.join(questions)}", UserWarning)
99
+
100
+ return questions
101
+
102
+ def choose_context_and_answer_questions(articles: List[Article],
103
+ query:str,
104
+ generative_client) -> Answer:
105
+ for article in articles:
106
+ completion = generative_client.chat.completions.create(
107
+ model=LLAMA_70B_MODEL,
108
+ messages=[
109
+ {
110
+ "role": "user",
111
+ "content": generate_context_retriver_prompt(query, article.content)
112
+ },
113
+ ],
114
+ temperature=1,
115
+ max_tokens=1024,
116
+ top_p=1,
117
+ stream=True,
118
+ stop=None,
119
+ )
120
+ answer = "".join(chunk.choices[0].delta.content or "" for chunk in completion)
121
+ pattern = r"answer_in_text\s*=\s*(.*?),"
122
+
123
+ # Appliquer la regex
124
+ match = re.search(pattern, answer)
125
+ if match:
126
+ if match.group(1) == "True":
127
+ pattern = r"answer\s*=\s*(.*)"
128
+ match = re.search(pattern, answer)
129
+ if match:
130
+ answer_value = match.group(1)[1:-2]
131
+ return Answer(
132
+ answer = answer_value,
133
+ link = f"{article.link}",
134
+ content = f"{article.content}"
135
+ )
136
+
137
+ return Answer(
138
+ answer = "incapable de générer une reponse",
139
+ link = f"{articles[0].link}",
140
+ content = f"{articles[0].content}"
141
+ )
142
+
143
+
144
+ def embed_content(contexts:List[str],
145
+ client: Client) -> NDArray:
146
+ return array(client.embed(
147
+ model=EMBEDING_MODEL,
148
+ texts=contexts,
149
+ input_type='classification',
150
+ truncate='NONE'
151
+ ).embeddings)
152
+
153
+ class CustomUnpickler(Unpickler):
154
+ def find_class(self, module, name):
155
+ if module == 'models':
156
+ return Article # Renvoie une classe de remplacement
157
+ return super().find_class(module, name)
158
+
159
+
160
+
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ beautifulsoup4
2
+ cohere
3
+ gradio
4
+ groq
5
+ matplotlib
6
+ numpy
7
+ pylint
8
+ requests
9
+ tqdm