acumplid committed
Commit f75ccae · 1 Parent(s): cdb6cea

initial commit

Files changed (7):
1. README.md +5 -5
2. app.py +246 -0
3. handler.py +14 -0
4. input_reader.py +22 -0
5. rag.py +180 -0
6. requirements.txt +14 -0
7. utils.py +33 -0
README.md CHANGED
@@ -1,10 +1,10 @@
 ---
-title: Wirag
-emoji: 🐠
-colorFrom: red
-colorTo: purple
+title: EADOP RAG
+emoji: 💻
+colorFrom: indigo
+colorTo: yellow
 sdk: gradio
-sdk_version: 5.27.0
+sdk_version: 4.24.0
 app_file: app.py
 pinned: false
 license: apache-2.0
app.py ADDED
@@ -0,0 +1,246 @@
+import os
+import gradio as gr
+from gradio.components import Textbox, Button, Slider, Checkbox
+from AinaTheme import theme
+from urllib.error import HTTPError
+
+from rag import RAG
+from utils import setup
+
+MAX_NEW_TOKENS = 700
+SHOW_MODEL_PARAMETERS_IN_UI = os.environ.get("SHOW_MODEL_PARAMETERS_IN_UI", default="True") == "True"
+
+setup()
+
+rag = RAG(
+    vs_hf_repo_path=os.getenv("VS_REPO_NAME"),
+    vectorstore_path=os.getenv("VECTORSTORE_PATH"),
+    hf_token=os.getenv("HF_TOKEN"),
+    embeddings_model=os.getenv("EMBEDDINGS"),
+    model_name=os.getenv("MODEL"),
+    rerank_model=os.getenv("RERANK_MODEL"),
+    rerank_number_contexts=int(os.getenv("RERANK_NUMBER_CONTEXTS")),
+)
+
+
+def generate(prompt, model_parameters):
+    try:
+        output, context, source = rag.get_response(prompt, model_parameters)
+        return output, context, source
+    except HTTPError as err:
+        if err.code == 400:
+            gr.Warning(
+                "The inference endpoint is only available Monday through Friday, from 08:00 to 20:00 CET."
+            )
+    except Exception:
+        gr.Warning(
+            "The inference endpoint is not available right now. Please try again later."
+        )
+    # Always hand back a 3-tuple so the caller can unpack safely.
+    return None, None, None
+
+
+def submit_input(input_, num_chunks, max_new_tokens, repetition_penalty, top_k, top_p, do_sample, temperature):
+    if input_.strip() == "":
+        gr.Warning("Cannot run inference on an empty input")
+        return None, None, None
+
+    model_parameters = {
+        "NUM_CHUNKS": num_chunks,
+        "max_new_tokens": max_new_tokens,
+        "repetition_penalty": repetition_penalty,
+        "top_k": top_k,
+        "top_p": top_p,
+        "do_sample": do_sample,
+        "temperature": temperature,
+    }
+
+    output, context, source = generate(input_, model_parameters)
+    sources_markup = ""
+
+    if source:
+        for url in source:
+            sources_markup += f'<a href="{url}" target="_blank">{url}</a><br>'
+
+    return output, sources_markup, context
+
+
+def change_interactive(text):
+    if len(text) == 0:
+        return gr.update(interactive=True), gr.update(interactive=False)
+    return gr.update(interactive=True), gr.update(interactive=True)
+
+
+def clear():
+    return (
+        None,
+        None,
+        None,
+        None,
+        gr.Slider(value=2.0),
+        gr.Slider(value=MAX_NEW_TOKENS),
+        gr.Slider(value=1.0),
+        gr.Slider(value=50),
+        gr.Slider(value=0.99),
+        gr.Checkbox(value=False),
+        gr.Slider(value=0.35),
+    )
+
+
+def gradio_app():
+    with gr.Blocks(theme=theme) as demo:
+        with gr.Row():
+            with gr.Column():
+                gr.Markdown(
+                    """# Retrieval-Augmented Generation demo for legal documents
+                    🔍 **Retrieval-Augmented Generation** (RAG) is an AI technique that lets you query a document repository with
+                    natural-language questions. It combines advanced information-retrieval techniques with generative models to
+                    draft an answer using only the information present in the repository's documents.
+
+                    🎯 **Goal:** This is a demonstrator built on the current legislation published in the Diari Oficial de la
+                    Generalitat de Catalunya, held in the EADOP (Entitat Autònoma del Diari Oficial i de Publicacions) repository.
+                    This version searches roughly 2,000 documents in Catalan and generates the answer with the
+                    Salamandra-7b-aligned-EADOP model, i.e. BSC-LT/salamandra-7b-instruct aligned with the
+                    alinia/EADOP-RAG-out-of-domain dataset.
+
+                    ⚠️ **Warnings:** This version is experimental. The content generated by this model is unsupervised and may be
+                    incorrect. Please keep that in mind when exploring this resource. The inference model behind this development
+                    demo does not run continuously. If you want to try it out, contact us at Langtech.
+
+                    👀 **More information in the reports on** [RAG](https://drive.google.com/file/d/11MgXQXAxfhkqbrx8syrKtmBrNP_6Qhx9/view?usp=sharing) and [Alignment](https://drive.google.com/file/d/1VUqHKO-gDmgMozK-Al83a2kh4Fr70pHh/view?usp=sharing) (PDF, in English).
+                    """
+                )
+        with gr.Row(equal_height=True):
+            with gr.Column(variant="panel"):
+                input_ = Textbox(
+                    lines=11,
+                    label="Input",
+                    placeholder="Quina és la finalitat del Servei Meteorològic de Catalunya?",
+                )
+                with gr.Row(variant="panel"):
+                    clear_btn = Button(
+                        "Clear",
+                    )
+                    submit_btn = Button("Submit", variant="primary", interactive=False)
+
+                with gr.Row(variant="panel"):
+                    with gr.Accordion("Model parameters", open=False, visible=SHOW_MODEL_PARAMETERS_IN_UI):
+                        num_chunks = Slider(
+                            minimum=1,
+                            maximum=6,
+                            step=1,
+                            value=2,
+                            label="Number of chunks",
+                        )
+                        max_new_tokens = Slider(
+                            minimum=50,
+                            maximum=2000,
+                            step=1,
+                            value=MAX_NEW_TOKENS,
+                            label="Max tokens",
+                        )
+                        repetition_penalty = Slider(
+                            minimum=0.1,
+                            maximum=2.0,
+                            step=0.1,
+                            value=1.0,
+                            label="Repetition penalty",
+                        )
+                        top_k = Slider(
+                            minimum=1,
+                            maximum=100,
+                            step=1,
+                            value=50,
+                            label="Top k",
+                        )
+                        top_p = Slider(
+                            minimum=0.01,
+                            maximum=0.99,
+                            value=0.99,
+                            label="Top p",
+                        )
+                        do_sample = Checkbox(
+                            value=False,
+                            label="Do sample",
+                        )
+                        temperature = Slider(
+                            minimum=0.1,
+                            maximum=1,
+                            value=0.35,
+                            label="Temperature",
+                        )
+
+                parameters_components = [num_chunks, max_new_tokens, repetition_penalty, top_k, top_p, do_sample, temperature]
+
+            with gr.Column(variant="panel"):
+                output = Textbox(
+                    lines=10,
+                    label="Output",
+                    interactive=False,
+                    show_copy_button=True,
+                )
+                with gr.Accordion("Sources and context:", open=False):
+                    source_context = gr.Markdown(
+                        label="Sources",
+                        show_label=False,
+                    )
+                with gr.Accordion("See full context evaluation:", open=False):
+                    context_evaluation = gr.Markdown(
+                        label="Full context",
+                        show_label=False,
+                    )
+
+        input_.change(
+            fn=change_interactive,
+            inputs=[input_],
+            outputs=[clear_btn, submit_btn],
+            api_name=False,
+        )
+
+        input_.change(
+            fn=None,
+            inputs=[input_],
+            api_name=False,
+            # Character counter: no element with id 'inputlenght' is created in this
+            # UI, so guard against a missing node instead of throwing on every keystroke.
+            js="""(i) => {
+                const el = document.getElementById('inputlenght');
+                if (!el) return;
+                el.textContent = i.length + ' ';
+            }""",
+        )
+
+        clear_btn.click(
+            fn=clear,
+            inputs=[],
+            outputs=[input_, output, source_context, context_evaluation] + parameters_components,
+            queue=False,
+            api_name=False,
+        )
+
+        submit_btn.click(
+            fn=submit_input,
+            inputs=[input_] + parameters_components,
+            outputs=[output, source_context, context_evaluation],
+            api_name="get-results",
+        )
+
+        with gr.Row():
+            with gr.Column(scale=0.5):
+                gr.Examples(
+                    examples=[
+                        ["""Qui va crear la guerra de les Galaxies ?"""],
+                    ],
+                    inputs=input_,
+                    outputs=[output, source_context, context_evaluation],
+                    fn=submit_input,
+                )
+
+    demo.launch(show_api=True)
+
+
+if __name__ == "__main__":
+    gradio_app()
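
Because submit_btn.click registers api_name="get-results", the Space can also be queried programmatically. A minimal sketch with gradio_client; the Space id is a hypothetical placeholder, and the positional arguments follow the inputs list above (input, then the seven model parameters):

from gradio_client import Client

client = Client("org/eadop-rag")  # hypothetical Space id
output, sources, context = client.predict(
    "Quina és la finalitat del Servei Meteorològic de Catalunya?",  # input_
    2,      # num_chunks
    700,    # max_new_tokens
    1.0,    # repetition_penalty
    50,     # top_k
    0.99,   # top_p
    False,  # do_sample
    0.35,   # temperature
    api_name="/get-results",
)
print(output)
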
handler.py ADDED
@@ -0,0 +1,14 @@
+import json
+
+
+class ContentHandler:
+    content_type = "application/json"
+    accepts = "application/json"
+
+    def transform_input(self, prompt: str, model_kwargs: dict) -> bytes:
+        input_str = json.dumps({'inputs': prompt, 'parameters': model_kwargs})
+        return input_str.encode('utf-8')
+
+    # `output` is a file-like response body (it must support .read()), not raw bytes.
+    def transform_output(self, output) -> str:
+        response_json = json.loads(output.read().decode("utf-8"))
+        return response_json[0]["generated_text"]
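
A quick sanity check for the handler above. transform_output expects a file-like body, so this sketch simulates one with io.BytesIO; the prompt and response strings are made up:

import io
from handler import ContentHandler

handler = ContentHandler()
payload = handler.transform_input("Quina és la finalitat del Servei Meteorològic de Catalunya?", {"max_new_tokens": 50})
print(payload)  # b'{"inputs": "...", "parameters": {"max_new_tokens": 50}}'

body = io.BytesIO(b'[{"generated_text": "Resposta de mostra"}]')
print(handler.transform_output(body))  # Resposta de mostra
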
input_reader.py ADDED
@@ -0,0 +1,22 @@
+from typing import List
+
+from llama_index.core.constants import DEFAULT_CHUNK_OVERLAP, DEFAULT_CHUNK_SIZE
+from llama_index.core.readers import SimpleDirectoryReader
+from llama_index.core.schema import Document
+from llama_index.core import Settings
+
+
+class InputReader:
+    def __init__(self, input_dir: str) -> None:
+        self.reader = SimpleDirectoryReader(input_dir=input_dir)
+
+    def parse_documents(
+        self,
+        show_progress: bool = True,
+        chunk_size: int = DEFAULT_CHUNK_SIZE,
+        chunk_overlap: int = DEFAULT_CHUNK_OVERLAP,
+    ) -> List[Document]:
+        # The chunking settings are applied globally; load_data itself returns
+        # whole documents, and chunking happens later when nodes are parsed.
+        Settings.chunk_size = chunk_size
+        Settings.chunk_overlap = chunk_overlap
+        documents = self.reader.load_data(show_progress=show_progress)
+        return documents
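
InputReader is not wired into the app anywhere in this commit. A minimal usage sketch, assuming a hypothetical local document folder:

from input_reader import InputReader

reader = InputReader(input_dir="data/eadop_docs")  # hypothetical path
documents = reader.parse_documents(chunk_size=512, chunk_overlap=50)
print(f"{len(documents)} documents loaded")
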
rag.py ADDED
@@ -0,0 +1,180 @@
+import logging
+import os
+import requests
+from transformers import AutoTokenizer, AutoModelForSequenceClassification
+import torch
+from openai import OpenAI
+from huggingface_hub import snapshot_download, InferenceClient
+
+from langchain_community.vectorstores import FAISS
+from langchain_community.embeddings import HuggingFaceEmbeddings
+
+
+class RAG:
+    NO_ANSWER_MESSAGE: str = "Ho sento, no he pogut respondre la teva pregunta."
+
+    # vectorstore = "index-intfloat_multilingual-e5-small-500-100-CA-ES"  # mixed
+    # vectorstore = "vectorestore"  # CA only
+    # vectorstore = "index-BAAI_bge-m3-1500-200-recursive_splitter-CA_ES_UE"
+
+    def __init__(self, vs_hf_repo_path, vectorstore_path, hf_token, embeddings_model, model_name, rerank_model, rerank_number_contexts):
+        self.vs_hf_repo_path = vs_hf_repo_path
+        self.vectorstore_path = vectorstore_path
+        self.model_name = model_name
+        self.hf_token = hf_token
+        self.rerank_model = rerank_model
+        self.rerank_number_contexts = rerank_number_contexts
+
+        # Load the vector store from the Hugging Face Hub
+        hf_vectorstore = snapshot_download(repo_id=vs_hf_repo_path)
+
+        embeddings = HuggingFaceEmbeddings(model_name=embeddings_model, model_kwargs={'device': 'cpu'})
+        self.vector_store = FAISS.load_local(hf_vectorstore, embeddings, allow_dangerous_deserialization=True)
+        logging.info("RAG loaded!")
+        logging.info(self.vector_store)
+
+    def rerank_contexts(self, instruction, contexts, number_of_contexts=1):
+        """Rerank the contexts based on their relevance to the given instruction."""
+        # Note: the cross-encoder is reloaded on every call; caching it in
+        # __init__ would avoid the repeated initialisation cost.
+        tokenizer = AutoTokenizer.from_pretrained(self.rerank_model)
+        model = AutoModelForSequenceClassification.from_pretrained(self.rerank_model)
+
+        def get_score(query, passage):
+            """Calculate the relevance score of a passage with respect to a query."""
+            inputs = tokenizer(query, passage, return_tensors='pt', truncation=True, padding=True, max_length=512)
+            with torch.no_grad():
+                outputs = model(**inputs)
+            logits = outputs.logits
+            score = logits.view(-1, ).float()
+            return score
+
+        scores = [get_score(instruction, c[0].page_content) for c in contexts]
+        combined = list(zip(contexts, scores))
+        sorted_combined = sorted(combined, key=lambda x: x[1], reverse=True)
+        sorted_texts, _ = zip(*sorted_combined)
+
+        return sorted_texts[:number_of_contexts]
+
+    def get_context(self, instruction, number_of_contexts=2):
+        """Retrieve the most relevant contexts for a given instruction."""
+        logging.info("RETRIEVE DOCUMENTS")
+        documentos = self.vector_store.similarity_search_with_score(instruction, k=self.rerank_number_contexts)
+        logging.info("RERANK DOCUMENTS")
+        documentos = self.rerank_contexts(instruction, documentos, number_of_contexts=number_of_contexts)
+        logging.info("Reranked documents")
+        return documentos
+
+    def predict_dolly(self, instruction, context, model_parameters):
+        api_key = os.getenv("HF_TOKEN")
+
+        headers = {
+            "Accept": "application/json",
+            "Authorization": f"Bearer {api_key}",
+            "Content-Type": "application/json"
+        }
+
+        query = f"### Instruction\n{instruction}\n\n### Context\n{context}\n\n### Answer\n "
+
+        payload = {
+            "inputs": query,
+            "parameters": model_parameters
+        }
+
+        response = requests.post(self.model_name, headers=headers, json=payload)
+
+        # The endpoint echoes the prompt: keep only the segment after the final
+        # "###" marker and strip the echoed " Answer\n" header.
+        return response.json()[0]["generated_text"].split("###")[-1][8:]
+
+    def predict_completion(self, instruction, context, model_parameters):
+        client = OpenAI(
+            base_url=os.getenv("MODEL"),
+            api_key=os.getenv("HF_TOKEN")
+        )
+
+        query = f"Context:\n{context}\n\nQuestion:\n{instruction}"
+
+        chat_completion = client.chat.completions.create(
+            model="tgi",
+            messages=[
+                {"role": "user", "content": query}
+            ],
+            temperature=model_parameters["temperature"],
+            max_tokens=model_parameters["max_new_tokens"],
+            stream=False,
+            stop=["<|im_end|>"],
+            extra_body={
+                # Map the repetition_penalty slider ([0.1, 2.0]) onto the
+                # OpenAI-style presence_penalty range; sampling is hard-coded off.
+                "presence_penalty": model_parameters["repetition_penalty"] - 2,
+                "do_sample": False
+            }
+        )
+
+        response = chat_completion.choices[0].message.content
+
+        return response
+
+
+    def beautiful_context(self, docs):
+        text_context = ""
+        full_context = ""
+        source_context = []
+        for doc in docs:
+            text_context += doc[0].page_content
+            full_context += doc[0].page_content + "\n"
+            full_context += doc[0].metadata["title"] + "\n\n"
+            full_context += doc[0].metadata["url"] + "\n\n"
+            source_context.append(doc[0].metadata["url"])
+
+        return text_context, full_context, source_context
+
+    def get_response(self, prompt: str, model_parameters: dict):
+        try:
+            docs = self.get_context(prompt, model_parameters["NUM_CHUNKS"])
+            text_context, full_context, source = self.beautiful_context(docs)
+
+            logging.info("text_context")
+            logging.info(text_context)
+            logging.info("full context")
+            logging.info(full_context)
+            logging.info("source")
+            logging.info(source)
+
+            del model_parameters["NUM_CHUNKS"]
+
+            response = self.predict_completion(prompt, text_context, model_parameters)
+
+            if not response:
+                # Keep the 3-tuple shape the caller unpacks.
+                return self.NO_ANSWER_MESSAGE, full_context, source
+
+            return response, full_context, source
+        except Exception as err:
+            # Re-raise so the error handling in app.py can warn the user.
+            logging.error(err)
+            raise
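
A minimal sketch of driving the class above directly, mirroring the constructor call and UI defaults in app.py. The repo, endpoint, and model names are hypothetical placeholders:

import os
from rag import RAG

rag = RAG(
    vs_hf_repo_path="org/eadop-faiss-index",       # hypothetical HF repo with the FAISS index
    vectorstore_path=None,                         # unused when loading from the Hub
    hf_token=os.getenv("HF_TOKEN"),
    embeddings_model="intfloat/multilingual-e5-small",  # hypothetical embeddings model
    model_name="https://my-endpoint.example/v1",   # hypothetical TGI base URL
    rerank_model="BAAI/bge-reranker-base",         # hypothetical reranker
    rerank_number_contexts=6,
)
response, full_context, sources = rag.get_response(
    "Quina és la finalitat del Servei Meteorològic de Catalunya?",
    {"NUM_CHUNKS": 2, "max_new_tokens": 700, "repetition_penalty": 1.0,
     "top_k": 50, "top_p": 0.99, "do_sample": False, "temperature": 0.35},
)
print(response)
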
requirements.txt ADDED
@@ -0,0 +1,14 @@
+gradio
+huggingface-hub
+python-dotenv
+llama-index
+llama-index-embeddings-huggingface
+llama-index-llms-huggingface
+sentence-transformers
+langchain
+faiss-cpu
+aina-gradio-theme
+
+langchain-community
+langchain-core
+openai
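
The app reads its configuration from environment variables, loaded by utils.setup() via python-dotenv. A hypothetical .env sketch covering every variable referenced in app.py and rag.py; all values below are placeholders, not taken from this commit:

HF_TOKEN=hf_xxxxxxxxxxxx
VS_REPO_NAME=org/eadop-faiss-index
VECTORSTORE_PATH=index-intfloat_multilingual-e5-small-500-100-CA-ES
EMBEDDINGS=intfloat/multilingual-e5-small
MODEL=https://my-endpoint.example/v1
RERANK_MODEL=BAAI/bge-reranker-base
RERANK_NUMBER_CONTEXTS=6
SHOW_MODEL_PARAMETERS_IN_UI=True
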
utils.py ADDED
@@ -0,0 +1,33 @@
+import logging
+import warnings
+
+from dotenv import load_dotenv
+
+from rag import RAG
+
+USER_INPUT = 100
+
+
+def setup():
+    load_dotenv()
+    warnings.filterwarnings("ignore")
+
+    logging.addLevelName(USER_INPUT, "USER_INPUT")
+    logging.basicConfig(format="[%(levelname)s]: %(message)s", level=logging.INFO)
+
+
+def interactive(model: RAG):
+    logging.info("Write `exit` when you want to stop the model.")
+    print()
+
+    query = ""
+    while query.lower() != "exit":
+        logging.log(USER_INPUT, "Write the query or `exit`:")
+        query = input()
+
+        if query.lower() == "exit":
+            break
+
+        # RAG.get_response also requires a model_parameters dict; reuse the UI
+        # defaults from app.py so this console helper actually runs.
+        response, _, _ = model.get_response(query, {
+            "NUM_CHUNKS": 2, "max_new_tokens": 700, "repetition_penalty": 1.0,
+            "top_k": 50, "top_p": 0.99, "do_sample": False, "temperature": 0.35,
+        })
+        print(response, end="\n\n")
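
utils.interactive is never called in this commit. A sketch of a console driver for it, mirroring how app.py constructs RAG from the same environment variables:

import os
from rag import RAG
from utils import setup, interactive

setup()
rag = RAG(
    vs_hf_repo_path=os.getenv("VS_REPO_NAME"),
    vectorstore_path=os.getenv("VECTORSTORE_PATH"),
    hf_token=os.getenv("HF_TOKEN"),
    embeddings_model=os.getenv("EMBEDDINGS"),
    model_name=os.getenv("MODEL"),
    rerank_model=os.getenv("RERANK_MODEL"),
    rerank_number_contexts=int(os.getenv("RERANK_NUMBER_CONTEXTS")),
)
interactive(rag)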