pierre Brault committed
Commit · 3ff674d
1 Parent(s): f0648ac
imit
Files changed:
- .gitignore +8 -0
- README.md +156 -9
- gossip_semantic_search/__init__.py +0 -0
- gossip_semantic_search/constant.py +15 -0
- gossip_semantic_search/dataset_creator/__init__.py +0 -0
- gossip_semantic_search/dataset_creator/__main__.py +32 -0
- gossip_semantic_search/dataset_creator/config.yml +6 -0
- gossip_semantic_search/dataset_creator/dataset_creator.py +62 -0
- gossip_semantic_search/dataset_processor/__init__.py +0 -0
- gossip_semantic_search/dataset_processor/__main__.py +30 -0
- gossip_semantic_search/dataset_processor/config.yml +2 -0
- gossip_semantic_search/dataset_processor/dataset_processor.py +51 -0
- gossip_semantic_search/evaluator/__main__.py +30 -0
- gossip_semantic_search/evaluator/config.yml +3 -0
- gossip_semantic_search/evaluator/evaluator.py +53 -0
- gossip_semantic_search/evaluator/k_plot_graph.png +0 -0
- gossip_semantic_search/front/__init__.py +0 -0
- gossip_semantic_search/front/app.py +42 -0
- gossip_semantic_search/front/predictor.py +57 -0
- gossip_semantic_search/models.py +27 -0
- gossip_semantic_search/prompts.py +17 -0
- gossip_semantic_search/utils.py +160 -0
- requirements.txt +9 -0
.gitignore
ADDED
@@ -0,0 +1,8 @@
```
venv/
.gradio/
.idea/
set_env.sh
datasets/
explo.ipynb
*.pyc
.DS_Store
```
README.md
CHANGED
@@ -1,12 +1,159 @@
---
(The old YAML front matter between these markers was removed; its content is not shown in this view.)
---

# Gossip Answering

## Quick Start

This application uses **Gradio** to build an interface for asking questions about articles from [VSD](https://vsd.fr/) and [Public](https://www.public.fr/), using **Retrieval-Augmented Generation (RAG)** techniques.

The project relies on two API services, and you will need to generate keys to access them:
- **[Cohere](https://dashboard.cohere.com/api-keys)** to generate embeddings.
- **[Groq](https://console.groq.com/keys)** to generate content.

---
### Colab ###

You can access the [notebook available on Google Colab here.](https://colab.research.google.com/drive/1pG4-s2Tg-9o7U0bRZrV1WELTy2k42qk5?usp=drive_link)

In the first cell, enter your GitHub token (used to clone the repo) and the keys for the two APIs used by this project.

Once that is done, run all the notebook cells, then open the `.gradio.live` link printed at the end of the last cell to access the application.

---
### Local ###
#### Prerequisites

- **Python 3.12.3** (the version used for this project)

#### Installation

1. **Clone the repository:**

   Clone the project to your local machine with:

   ```
   git clone https://github.com/PierreBdesG/gossip-semantic-search.git
   ```

2. **Set the API keys:**

   Define your API keys as environment variables:

   ```
   export COHERE_API_KEY=your_cohere_api_key
   export GROQ_API_KEY=your_groq_api_key
   ```

3. **Install the dependencies:**

   Install the required libraries from `requirements.txt`:

   ```
   pip install -r requirements.txt
   ```

4. **Download the dataset:**

   Download the dataset (updated October 19, 2024):

   ```
   curl -L -o dataset.pkl "https://drive.google.com/uc?export=download&id=1g4-pNY3CTRIW5_hinWBREMkTrsZ05GNd"
   ```

5. **Launch the application:**

   Run the application with the following command, passing the path to the downloaded dataset:

   ```
   python gossip_semantic_search/front --dataset_path dataset.pkl
   ```

---

## Dataset Creation

The dataset used in this project is a **`List[Article]`**, where each element is an `Article` object holding the following information:

```
class Article:
    author: str
    title: str
    link: str
    description: str
    published_date: datetime
    content: str
    embeded_content: NDArray
    questions: list[str]
```

- The **`author`**, **`title`**, **`link`**, **`description`**, **`published_date`**, and **`content`** attributes are extracted from the RSS feeds of [VSD](https://vsd.fr/feed) and [Public](https://www.public.fr/feed).
- The **`embeded_content`** field holds the embedding generated for each article. A detailed description of the embedding model is available [here](https://cohere.com/blog/introducing-embed-v3).
- The **`questions`** field contains a list of `n` questions associated with the article, generated with [`Llama-3-70B`](https://huggingface.co/meta-llama/Meta-Llama-3-70B) through the API provided by [Groq](https://groq.com/).

---

## Evaluation ##

For this dataset we chose to generate 3 questions per article. With 60 articles in total, this yields a dataset of roughly 180 questions.

Once this dataset is built, we can evaluate the retrieval model as follows (a minimal sketch follows the list):
- Compute the dot product between all question embeddings and all context embeddings (equivalent to cosine similarity, since the embeddings have norm 1).
- For each question, retrieve the `k` best-matching contexts.
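As a quick sanity check on the shortcut above: for unit-norm vectors, cosine similarity reduces to the plain dot product, since cos(q, c) = q·c / (‖q‖ ‖c‖) = q·c. A minimal sketch with random stand-in vectors (not the real embeddings; the sizes mirror the dataset described above):

```
import numpy as np

rng = np.random.default_rng(0)
queries = rng.normal(size=(180, 1024))   # stand-ins for 180 question embeddings
contexts = rng.normal(size=(60, 1024))   # stand-ins for 60 article embeddings

# Normalize rows to unit norm, as assumed for the real embeddings.
queries /= np.linalg.norm(queries, axis=1, keepdims=True)
contexts /= np.linalg.norm(contexts, axis=1, keepdims=True)

sim = queries @ contexts.T                         # (180, 60) similarity matrix
top_3 = np.argsort(sim, axis=1)[:, -3:][:, ::-1]   # 3 best contexts per question
```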
This gives the following results:

![k-plot](gossip_semantic_search/evaluator/k_plot_graph.png)

/!\ regenerate the graph with the x-axis shifted by 1

#### Graph Results

The graph shows the following behavior for context retrieval:

- Using only the highest **dot product** for each question, the source context is retrieved in **91.5% of cases**.
- Widening the selection to the **top 2 contexts**, the source context is among the two selected ones in about **99% of cases**.
- With the **top 3 contexts**, accuracy reaches **99.5%**, the source context being among the three selected ones.

#### Limitations ####

This dataset is **synthetic**. The questions were generated by **[Llama-3-70B](https://huggingface.co/meta-llama/Meta-Llama-3-70B)**, each one conditioned on a specific article. The questions are therefore designed to match the article content directly, which biases the dataset relative to real usage: the questions are closely aligned with the information present in the text.
At inference time, the questions users ask will not always be as structured or as specific as the generated ones. Real questions from human users tend to be more organic, more varied, and less directly tied to a particular context.
As a result, the model's performance will be lower when faced with these more natural questions.

---

## Inference and answer generation ##

At inference time the pipeline works as follows (a condensed sketch follows the list):

1. A user enters a question.

2. The question is converted into an embedding.

3. A dot product is computed between the question embedding and all article embeddings, which were loaded into RAM at application startup.

4. The 3 most relevant articles are retrieved.

5. We take the first article and ask **[Llama-3-70B](https://huggingface.co/meta-llama/Meta-Llama-3-70B)** whether the answer to the question is found in that article's content.
   - If yes, we ask **Llama-3-70B** to generate the answer, then return the answer, the article link, and the matching context.
   - If not, we move on to the second article and repeat the procedure.
   - If the answer is found in none of the 3 articles, we return the message **"incapable de générer une reponse"** ("unable to generate an answer") together with the first context.
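For reference, here is a condensed, hypothetical sketch of steps 1-5; the shipped implementation lives in `gossip_semantic_search/front/predictor.py`, and `answer_query` is just an illustrative wrapper name:

```
import numpy as np
from cohere import Client
from groq import Groq

from gossip_semantic_search.utils import (embed_content,
                                          choose_context_and_answer_questions)

def answer_query(query, articles, embedded_contents,
                 embedding_client: Client, generative_client: Groq):
    # 1-2. Embed the user's question.
    embedded_query = embed_content([query], embedding_client)
    # 3. Dot product against every article embedding held in RAM.
    similarity = np.dot(embedded_query, embedded_contents.T)
    # 4. Indices of the 3 most relevant articles, best first.
    best = np.argsort(similarity, axis=1)[:, -3:][:, ::-1][0]
    # 5. Let Llama-3-70B pick a context and generate the answer.
    return choose_context_and_answer_questions(
        [articles[i] for i in best], query, generative_client)
```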
------

## Exploration
With more time, I would have:

- Generated a more challenging dataset.
- Implemented a KNN retriever instead of the plain dot product.
- Explored and analyzed the misclassified questions in more depth.
- Added a question reformulator (multi-query, [HyDE](https://arxiv.org/abs/2212.10496)).
- Explored Cohere and other embedders further.
- Applied ColBERT to the K best contexts.

---

Still to do:
- fix the pickle issue
- replot the graph
- tests
- linter
gossip_semantic_search/__init__.py
ADDED
File without changes
gossip_semantic_search/constant.py
ADDED
@@ -0,0 +1,15 @@
```
HEADERS = {
    'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) '
                  'AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36'
}
CHANNEL_KEY = "channel"
ITEM_KEY = "item"
AUTHOR_KEY = "{http://purl.org/dc/elements/1.1/}creator"
TITLE_KEY = "title"
LINK_KEY = "link"
DESCRIPTION_KEY = "description"
PUBLICATION_DATE_KEY = "pubDate"
CONTENT_KEY = "{http://purl.org/rss/1.0/modules/content/}encoded"
LLAMA_70B_MODEL = "llama3-groq-70b-8192-tool-use-preview"
EMBEDING_MODEL = "embed-multilingual-v3.0"
DATE_FORMAT = "%a, %d %b %Y %H:%M:%S %z"
```
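The braced prefixes in `AUTHOR_KEY` and `CONTENT_KEY` are not arbitrary: `xml.etree.ElementTree` renders a namespaced XML tag as `{namespace-uri}localname`, so the RSS `dc:creator` and `content:encoded` elements have to be looked up under these expanded names. A short standalone illustration:

```
from xml.etree.ElementTree import fromstring

item = fromstring(
    '<item xmlns:dc="http://purl.org/dc/elements/1.1/">'
    '<dc:creator>pierre</dc:creator>'
    '</item>'
)
child = item[0]
print(child.tag)   # {http://purl.org/dc/elements/1.1/}creator
print(child.text)  # pierre
```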
gossip_semantic_search/dataset_creator/__init__.py
ADDED
File without changes
gossip_semantic_search/dataset_creator/__main__.py
ADDED
@@ -0,0 +1,32 @@
```
import sys
import os

# Make the package importable when this file is run directly.
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..'))
sys.path.append(project_root)

import argparse

import yaml

from dataset_creator import DatasetCreator

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--config', type=str, help='Path to the YAML configuration file.', required=True
    )

    args = parser.parse_args()
    with open(args.config, 'r') as file:
        config = yaml.safe_load(file)

    dataset_creator = DatasetCreator(urls=config.get('urls', []),
                                     save_path=config.get('save_path'),
                                     number_questions=config.get('number_questions'),
                                     embed_articles=config.get('embed_articles'))
    dataset_creator.run()

if __name__ == '__main__':
    main()
```
gossip_semantic_search/dataset_creator/config.yml
ADDED
@@ -0,0 +1,6 @@
```
urls:
  - "https://www.public.fr/feed"
  - "https://www.vsd.fr/feed"
save_path: "datasets/dataset.pkl"
number_questions: 3
embed_articles: True
```
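With this configuration in place, the dataset builder can be invoked from the repository root, mirroring the launch command used for the front end (the config path below is an assumption based on the layout above):

```
python gossip_semantic_search/dataset_creator --config gossip_semantic_search/dataset_creator/config.yml
```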
gossip_semantic_search/dataset_creator/dataset_creator.py
ADDED
@@ -0,0 +1,62 @@
```
import pickle
from typing import List
from xml.etree.ElementTree import fromstring

import requests
from tqdm import tqdm
from groq import Groq
from cohere import Client

from gossip_semantic_search.constant import HEADERS, CHANNEL_KEY, ITEM_KEY
from gossip_semantic_search.utils import (xml_to_dict, article_raw_to_article,
                                          generates_questions, embed_content)
from gossip_semantic_search.models import Article


class DatasetCreator:
    def __init__(self,
                 urls: List[str],
                 save_path: str = None,
                 number_questions: int = 0,
                 embed_articles: bool = False):

        self.urls = urls
        self.save_path = save_path
        self.number_questions = number_questions
        self.embed_articles = embed_articles

        self.articles: List[Article] = []

    def extract_articles(self):
        # Fetch each RSS feed and convert every <item> into an Article.
        for url in self.urls:
            response = requests.get(url, headers=HEADERS)
            xml_string = response.text
            root = fromstring(xml_string)
            articles_raw = xml_to_dict(root)[CHANNEL_KEY][ITEM_KEY]
            self.articles.extend([article_raw_to_article(article_raw)
                                  for article_raw in articles_raw])

    def save_articles(self):
        with open(self.save_path, 'wb') as f:
            pickle.dump(self.articles, f)

    def generate_questions(self):
        client = Groq()
        for article in tqdm(self.articles, desc="Generating questions"):
            article.questions = generates_questions(article.content, self.number_questions, client)

    def embed_article(self):
        client = Client()
        for article in tqdm(self.articles, desc="Embedding content"):
            article.embeded_content = embed_content([article.content], client)[0, :]

    def run(self):
        self.extract_articles()

        if self.number_questions:
            self.generate_questions()

        if self.embed_articles:
            self.embed_article()

        if self.articles:
            self.save_articles()
```
gossip_semantic_search/dataset_processor/__init__.py
ADDED
File without changes
gossip_semantic_search/dataset_processor/__main__.py
ADDED
@@ -0,0 +1,30 @@
```
import sys
import os

project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..'))
sys.path.append(project_root)

import argparse

import yaml

from dataset_processor import DatasetProcessor

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--config', type=str, help='Path to the YAML configuration file.', required=True
    )

    args = parser.parse_args()
    with open(args.config, 'r') as file:
        config = yaml.safe_load(file)

    dataset_processor = DatasetProcessor(
        dataset_path=config.get('dataset_path'),
        saved_processed_dataset_path=config.get('saved_processed_dataset_path'))
    dataset_processor.run()

if __name__ == '__main__':
    main()
```
gossip_semantic_search/dataset_processor/config.yml
ADDED
@@ -0,0 +1,2 @@
```
dataset_path: "datasets/dataset.pkl"
saved_processed_dataset_path: "datasets/processed_dataset.pkl"
```
gossip_semantic_search/dataset_processor/dataset_processor.py
ADDED
@@ -0,0 +1,51 @@
```
from pickle import dump
from typing import List

from cohere import Client
from numpy import array

from gossip_semantic_search.models import Article, ProcessedDataset
from gossip_semantic_search.utils import embed_content, CustomUnpickler


class DatasetProcessor:
    def __init__(self,
                 dataset_path: str,
                 saved_processed_dataset_path: str):
        self.dataset_path = dataset_path
        self.saved_processed_dataset_path = saved_processed_dataset_path

        self.processed_dataset: ProcessedDataset = None

    @staticmethod
    def load_dataset(dataset_path: str) -> List[Article]:
        with open(dataset_path, 'rb') as file:
            unpickler = CustomUnpickler(file)
            data = unpickler.load()
        return data

    def process_dataset(self,
                        data: List[Article]):
        client = Client()

        # y_true[j] is the index of the article that question j was generated from.
        y_true = []
        questions = []
        for i, sample in enumerate(data):
            for question in sample.questions:
                y_true.append(i)
                questions.append(question)

        self.processed_dataset = ProcessedDataset(
            y_true=array(y_true),
            embedded_queries=embed_content(questions, client),
            embedded_context=array([sample.embeded_content for sample in data]))

    def save_articles(self):
        with open(self.saved_processed_dataset_path, 'wb') as f:
            dump(self.processed_dataset, f)

    def run(self):
        data = self.load_dataset(self.dataset_path)
        self.process_dataset(data)
        self.save_articles()
```
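To make the label layout concrete: the processor flattens questions article by article, so `y_true` repeats each article index once per question, and row `j` of `embedded_queries` corresponds to article `y_true[j]`. A toy illustration with 2 articles and 3 questions each (hypothetical values, no API calls):

```
questions_per_article = [["q0a", "q0b", "q0c"],   # questions of article 0
                         ["q1a", "q1b", "q1c"]]   # questions of article 1

y_true, questions = [], []
for i, qs in enumerate(questions_per_article):
    for q in qs:
        y_true.append(i)
        questions.append(q)

print(y_true)     # [0, 0, 0, 1, 1, 1]
print(questions)  # ['q0a', 'q0b', 'q0c', 'q1a', 'q1b', 'q1c']
```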
gossip_semantic_search/evaluator/__main__.py
ADDED
@@ -0,0 +1,30 @@
```
import sys
import os

project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..'))
sys.path.append(project_root)

import argparse

import yaml

from evaluator import Evaluator

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, help='Path to the YAML configuration file.', required=True)

    args = parser.parse_args()
    with open(args.config, 'r') as file:
        config = yaml.safe_load(file)

    evaluator = Evaluator(
        processed_dataset_path=config.get('processed_dataset_path'),
        k_max=config.get('k_max'),
        path_k_plot=config.get('path_k_plot'),
    )
    evaluator.run()

if __name__ == '__main__':
    main()
```
gossip_semantic_search/evaluator/config.yml
ADDED
@@ -0,0 +1,3 @@
```
processed_dataset_path: "datasets/processed_dataset.pkl"
k_max: 6
path_k_plot: "gossip_semantic_search/evaluator/k_plot_graph.png"
```
gossip_semantic_search/evaluator/evaluator.py
ADDED
@@ -0,0 +1,53 @@
```
from numpy import argsort, dot, array, any, isin
from matplotlib import pyplot as plt
from gossip_semantic_search.models import ProcessedDataset
from gossip_semantic_search.utils import CustomUnpickler


class Evaluator:
    def __init__(self,
                 processed_dataset_path: str,
                 k_max: int = None,
                 path_k_plot: str = None):

        self.processed_dataset_path = processed_dataset_path
        self.k_max = k_max
        self.path_k_plot = path_k_plot

        self.processed_dataset: ProcessedDataset = None

    def load_dataset(self):
        with open(self.processed_dataset_path, 'rb') as file:
            unpickler = CustomUnpickler(file)
            return unpickler.load()

    def evaluate_k_context(self,
                           k: int) -> float:
        # Similarity between every question and every context.
        similarity_matrix = dot(self.processed_dataset.embedded_queries,
                                self.processed_dataset.embedded_context.T)
        # Indices of the k most similar contexts per question, best first.
        k_best_context = argsort(similarity_matrix, axis=1)[:, -k:][:, ::-1]
        # 1 when the source context is among the k retrieved ones.
        result = array([1 if any(isin(true_idx, retrieved)) else 0
                        for true_idx, retrieved in zip(self.processed_dataset.y_true,
                                                       k_best_context)])
        return result.mean()

    def plot_top_k_graph(self):
        accuracies = []
        for k in range(1, self.k_max):  # evaluates k = 1 .. k_max - 1
            accuracies.append(self.evaluate_k_context(k))
        # Plot against k = 1 .. k_max - 1 so the x axis is not shifted by one.
        plt.plot(range(1, self.k_max), accuracies, marker='o', color='b',
                 linestyle='-', markersize=8, linewidth=2, label="top-k accuracy")

        # Style the plot.
        plt.title("Taux de contextes corrects parmi les k meilleurs contextes",
                  fontsize=14, fontweight='bold')
        plt.xlabel("k", fontsize=12)
        plt.ylabel("Accuracy", fontsize=12)
        plt.grid(True, which='both', linestyle='--', linewidth=0.5)
        plt.xticks(fontsize=10)
        plt.yticks(fontsize=10)
        plt.legend()
        plt.savefig(self.path_k_plot)

    def run(self):
        self.processed_dataset = self.load_dataset()
        if self.k_max:
            self.plot_top_k_graph()
```
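The `argsort(...)[:, -k:][:, ::-1]` idiom in `evaluate_k_context` is worth unpacking: `argsort` sorts each row in ascending order, `[:, -k:]` keeps the indices of the k largest similarities, and `[:, ::-1]` flips them so the best context comes first. A small standalone check:

```
import numpy as np

sim = np.array([[0.1, 0.9, 0.4],
                [0.8, 0.2, 0.3]])
top2 = np.argsort(sim, axis=1)[:, -2:][:, ::-1]
print(top2)  # [[1 2]
             #  [0 2]]
```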
gossip_semantic_search/evaluator/k_plot_graph.png
ADDED
gossip_semantic_search/front/__init__.py
ADDED
File without changes
gossip_semantic_search/front/app.py
ADDED
@@ -0,0 +1,42 @@
```
import sys
import os

project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..'))
sys.path.append(project_root)

import argparse

import gradio as gr
from cohere import Client
from groq import Groq
from predictor import Predictor

# Read the keys from the environment (see the README) instead of
# hard-coding secrets in the source.
COHERE_API_KEY = os.environ["COHERE_API_KEY"]
GROQ_API_KEY = os.environ["GROQ_API_KEY"]

def main():
    # The README launches the app with --dataset_path; parse it here.
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset_path', type=str, default="datasets/dataset.pkl",
                        help='Path to the pickled dataset.')
    args = parser.parse_args()

    embeding_client = Client(api_key=COHERE_API_KEY)
    gen_client = Groq(api_key=GROQ_API_KEY)

    predictor = Predictor(dataset_path=args.dataset_path,
                          embeding_client=embeding_client,
                          QA_boosted=True,
                          generative_client=gen_client)
    predictor.setup()

    def make_prediction(query):
        answer = predictor.make_prediction(query)
        return answer.answer, answer.link, answer.content

    iface = gr.Interface(
        fn=make_prediction,
        inputs=gr.Textbox(label="query", value="Alain Delon connait-il Anne-Elisabeth Lemoine?"),
        outputs=[gr.Textbox(label="reponse"),
                 gr.Textbox(label="link"),
                 gr.Textbox(label="context")],
        title="Gossip answering",
        description="Posez une question Gossip, peut-être y aura-t-il la réponse"
    )

    iface.launch()

if __name__ == '__main__':
    main()
```
gossip_semantic_search/front/predictor.py
ADDED
@@ -0,0 +1,57 @@
```
from typing import List

from cohere import Client
from groq import Groq
from numpy import array, argsort, dot
from numpy.typing import NDArray

from gossip_semantic_search.models import Article, Answer
from gossip_semantic_search.utils import (CustomUnpickler, embed_content,
                                          choose_context_and_answer_questions)


class Predictor:
    def __init__(self,
                 dataset_path: str,
                 embeding_client: Client,
                 QA_boosted: bool = False,
                 generative_client: Groq = None):
        self.dataset_path = dataset_path
        self.embeding_client = embeding_client
        self.generative_client = generative_client
        self.QA_boosted = QA_boosted

        self.dataset: List[Article] = []
        self.embeded_contents: NDArray = None

    @staticmethod
    def load_dataset(dataset_path: str) -> List[Article]:
        with open(dataset_path, 'rb') as file:
            unpickler = CustomUnpickler(file)
            data = unpickler.load()
        return data

    def setup(self):
        # Load the pickled articles and stack their embeddings in RAM.
        self.dataset = self.load_dataset(self.dataset_path)
        self.embeded_contents = array([sample.embeded_content for sample in self.dataset])

    def make_prediction(self,
                        query: str) -> Answer:
        embeded_query = embed_content([query], self.embeding_client)
        similarity_vector = dot(embeded_query, self.embeded_contents.T)
        # Indices of the 3 most similar articles, best first.
        k_best_article_index = argsort(similarity_vector, axis=1)[:, -3:][:, ::-1]

        if self.QA_boosted:
            k_best_article = [self.dataset[k] for k in k_best_article_index[0]]
            return choose_context_and_answer_questions(
                k_best_article,
                query,
                self.generative_client
            )

        best_article = self.dataset[k_best_article_index[0][0]]
        return Answer(
            answer="pas de reponse, activer QA_boosted pour générer des reponses",
            link=f"{best_article.link}",
            content=f"{best_article.content}"
        )
```
gossip_semantic_search/models.py
ADDED
@@ -0,0 +1,27 @@
```
from dataclasses import dataclass
from datetime import datetime

from numpy.typing import NDArray


@dataclass
class Article:
    author: str
    title: str
    link: str
    description: str
    published_date: datetime
    content: str
    embeded_content: NDArray = None
    questions: list[str] = None


@dataclass
class Answer:
    link: str
    content: str
    answer: str


@dataclass
class ProcessedDataset:
    y_true: NDArray[int]
    embedded_queries: NDArray[float]
    embedded_context: NDArray[float]
```
gossip_semantic_search/prompts.py
ADDED
@@ -0,0 +1,17 @@
```
# The prompts are intentionally kept in French: the articles and the
# generated questions are in French.
def generate_question_prompt(context, nb_questions):
    return (f"À partir du contexte suivant : '{context}',"
            f" \n générez {nb_questions} questions pertinentes en français qui pourraient "
            f"être posées en lien avec ce contexte.\n"
            f"Les questions devront avoir la forme :\n" +
            "\n".join([f"{i + 1}. ceci est-il une question {i + 1}?" for i in range(nb_questions)]))


def generate_context_retriver_prompt(query, context):
    return (f"je vais te donner 1 texte et une question, tu devras me dire si la réponse est dans le texte "
            f"et si elle y est, me générer une réponse. \n"
            f"voici la question: {query} \n "
            f"voici le texte: {context} \n "
            f"si la réponse est dans le texte, ta réponse doit avoir la forme suivante : \n "
            f" (answer_in_text=True, answer='La réponse générée')"
            f"si la réponse n'est pas dans le texte, ta réponse doit avoir la forme suivante :"
            f" (answer_in_text=False, answer=None)")
```
gossip_semantic_search/utils.py
ADDED
@@ -0,0 +1,160 @@
```
from datetime import datetime
import warnings
from typing import List
from pickle import Unpickler
import re

from bs4 import BeautifulSoup
from groq import Groq
from cohere import Client
from numpy.typing import NDArray
from numpy import array

from gossip_semantic_search.models import Article, Answer
from gossip_semantic_search.constant import (AUTHOR_KEY, TITLE_KEY, LINK_KEY, DESCRIPTION_KEY,
                                             PUBLICATION_DATE_KEY, CONTENT_KEY, LLAMA_70B_MODEL,
                                             DATE_FORMAT, EMBEDING_MODEL)
from gossip_semantic_search.prompts import (generate_question_prompt,
                                            generate_context_retriver_prompt)


def xml_to_dict(element):
    # Recursively convert an ElementTree element into nested dicts;
    # repeated child tags (e.g. every <item> in a feed) become lists,
    # and leaf elements collapse to their stripped text.
    result = {}

    for child in element:
        child_dict = xml_to_dict(child)
        if child.tag in result:
            if isinstance(result[child.tag], list):
                result[child.tag].append(child_dict)
            else:
                result[child.tag] = [result[child.tag], child_dict]
        else:
            result[child.tag] = child_dict

    if element.text and element.text.strip():
        result = element.text.strip()

    return result
```
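Run alongside the function above, a cut-down feed fragment shows the shape this produces, and why `dataset_creator.py` can index the result with `[CHANNEL_KEY][ITEM_KEY]`:

```
from xml.etree.ElementTree import fromstring

rss = ("<channel>"
       "<title>Public</title>"
       "<item><title>a</title></item>"
       "<item><title>b</title></item>"
       "</channel>")
print(xml_to_dict(fromstring(rss)))
# {'title': 'Public', 'item': [{'title': 'a'}, {'title': 'b'}]}
```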
```
def sanitize_html_content(html_content):
    # Strip links, emphasis and blockquotes, then collapse whitespace.
    soup = BeautifulSoup(html_content, 'html.parser')

    for a in soup.find_all('a'):
        a.unwrap()

    for tag in soup.find_all(['em', 'strong']):
        tag.unwrap()

    for blockquote in soup.find_all('blockquote'):
        blockquote.extract()

    cleaned_text = re.sub(r'\s+', ' ', soup.get_text()).strip()
    return cleaned_text

def article_raw_to_article(raw_article) -> Article:
    return Article(
        author=raw_article[AUTHOR_KEY],
        title=raw_article[TITLE_KEY],
        link=raw_article[LINK_KEY],
        description=raw_article[DESCRIPTION_KEY],
        published_date=datetime.strptime(
            raw_article[PUBLICATION_DATE_KEY],
            DATE_FORMAT
        ),
        content=sanitize_html_content(raw_article[CONTENT_KEY])
    )

def generates_questions(context: str,
                        nb_questions: int,
                        client: Groq) -> List[str]:
    completion = client.chat.completions.create(
        model=LLAMA_70B_MODEL,
        messages=[
            {
                "role": "user",
                "content": generate_question_prompt(context, nb_questions)
            },
        ],
        temperature=1,
        max_tokens=1024,
        top_p=1,
        stream=True,
        stop=None,
    )
    questions_str = "".join(chunk.choices[0].delta.content or "" for chunk in completion)

    # Keep every "...?" span, then drop the "N. " numbering prefix.
    questions = re.findall(r'([^?]*\?)', questions_str)
    questions = [question.strip()[3:] for question in questions]

    if not questions:
        warnings.warn(f"no question found. \n"
                      f"string returned: {questions_str}")
        return []

    if len(questions) != nb_questions:
        warnings.warn(f"Expected {nb_questions} questions, but found "
                      f"{len(questions)}. {', '.join(questions)}", UserWarning)

    return questions

def choose_context_and_answer_questions(articles: List[Article],
                                        query: str,
                                        generative_client) -> Answer:
    for article in articles:
        completion = generative_client.chat.completions.create(
            model=LLAMA_70B_MODEL,
            messages=[
                {
                    "role": "user",
                    "content": generate_context_retriver_prompt(query, article.content)
                },
            ],
            temperature=1,
            max_tokens=1024,
            top_p=1,
            stream=True,
            stop=None,
        )
        answer = "".join(chunk.choices[0].delta.content or "" for chunk in completion)

        # Parse the "(answer_in_text=..., answer=...)" shape requested by the prompt.
        pattern = r"answer_in_text\s*=\s*(.*?),"
        match = re.search(pattern, answer)
        if match:
            if match.group(1) == "True":
                pattern = r"answer\s*=\s*(.*)"
                match = re.search(pattern, answer)
                if match:
                    answer_value = match.group(1)[1:-2]  # strip quotes and closing ")"
                    return Answer(
                        answer=answer_value,
                        link=f"{article.link}",
                        content=f"{article.content}"
                    )

    return Answer(
        answer="incapable de générer une reponse",
        link=f"{articles[0].link}",
        content=f"{articles[0].content}"
    )


def embed_content(contexts: List[str],
                  client: Client) -> NDArray:
    return array(client.embed(
        model=EMBEDING_MODEL,
        texts=contexts,
        input_type='classification',
        truncate='NONE'
    ).embeddings)

class CustomUnpickler(Unpickler):
    def find_class(self, module, name):
        # Return Article as the replacement class for anything pickled
        # under the old top-level 'models' module.
        if module == 'models':
            return Article
        return super().find_class(module, name)
```
requirements.txt
ADDED
@@ -0,0 +1,9 @@
```
beautifulsoup4
cohere
gradio
groq
matplotlib
numpy
pylint
pyyaml
requests
tqdm
```