Matteo-CNPPS committed
Commit 2ce0b48 · Parent: 8176958

test_commit
.ipynb_checkpoints/app-checkpoint.py CHANGED
@@ -1,43 +1,95 @@
  import gradio as gr
- from huggingface_hub import InferenceClient
-
- """
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
- """
- client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
-
-
- def respond(
-     message,
-     history: list[tuple[str, str]],
-     system_message,
-     max_tokens,
-     temperature,
-     top_p,
- ):
-     messages = [{"role": "system", "content": system_message}]
-
-     for val in history:
-         if val[0]:
-             messages.append({"role": "user", "content": val[0]})
-         if val[1]:
-             messages.append({"role": "assistant", "content": val[1]})
-
-     messages.append({"role": "user", "content": message})
-
-     response = ""
-
-     for message in client.chat_completion(
-         messages,
-         max_tokens=max_tokens,
-         stream=True,
-         temperature=temperature,
-         top_p=top_p,
-     ):
-         token = message.choices[0].delta.content
-
-         response += token
-         yield response
-
-
+ from huggingface_hub import HfApiModel
+ import sys
+ if './lib' not in sys.path :
+     sys.path.append('./lib')
+ from ingestion_chroma import retrieve_info_from_db
+
+ ############################################################################################
+ ################################### TOOLS ##################################################
+ ############################################################################################
+
+ def find_key(data, target_key):
+     if isinstance(data, dict):
+         for key, value in data.items():
+             if key == target_key:
+                 return value
+             else:
+                 result = find_key(value, target_key)
+                 if result is not None:
+                     return result
+     return "Indicator not found"
+
+ ############################################################################################
+
+ class Chroma_retrieverTool(Tool):
+     name = "request"
+     description = "Using semantic similarity, retrieve the text from the knowledge base that has the embedding closest to the query."
+     inputs = {
+         "query": {
+             "type": "string",
+             "description": "The query to execute must be semantically close to the text to search. Use the affirmative form rather than a question.",
+         },
+     }
+     output_type = "string"
+
+     def forward(self, query: str) -> str:
+         assert isinstance(query, str), "The request needs to be a string."
+
+         query_results = retrieve_info_from_db(query)
+         str_result = "\nRetrieval texts : \n" + "".join([f"===== Text {str(i)} =====\n" + query_results['documents'][0][i] for i in range(len(query_results['documents'][0]))])
+
+         return str_result
+
+ ############################################################################################
+
+ class ESRS_info_tool(Tool):
+     name = "find_ESRS"
+     description = "Find ESRS description to help you to find what indicators the user want"
+     inputs = {
+         "indicator": {
+             "type": "string",
+             "description": "The indicator name. return the description of the indicator demanded.",
+         },
+     }
+     output_type = "string"
+
+     def forward(self, indicator: str) -> str:
+         assert isinstance(indicator, str), "The request needs to be a string."
+
+         with open('./data/dico_esrs.json') as json_data:
+             dico_esrs = json.load(json_data)
+
+         result = find_key(dico_esrs, indicator)
+
+         return result
+
+ ############################################################################################
+ ############################################################################################
+ ############################################################################################
+
+ model = HfApiModel("Qwen/Qwen2.5-Coder-32B-Instruct")
+
+ retriever_tool = Chroma_retrieverTool()
+ get_ESRS_info_tool = ESRS_info_tool()
+ agent = CodeAgent(
+     tools=[
+         get_ESRS_info_tool,
+         retriever_tool,
+     ],
+     model=model,
+     max_steps=10,
+     max_print_outputs_length=16000,
+     additional_authorized_imports=['pandas', 'matplotlib', 'datetime']
+ )
+
+ def respond(message):
+     system_prompt_added = """You are an expert in environmental and corporate social responsibility. You must respond to requests using the query function in the document database.
+     User's question : """
+     agent_output = agent.run(system_prompt_added+"""Find all informations about the ESRS E1–5: Energy consumption from fossil sources in Sartorius documents.""")
+
+     yield agent_output
+
  """
@@ -45,18 +97,6 @@ For information on how to customize the ChatInterface, peruse the gradio docs: h
  """
  demo = gr.ChatInterface(
      respond,
-     additional_inputs=[
-         gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-         gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-         gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-         gr.Slider(
-             minimum=0.1,
-             maximum=1.0,
-             value=0.95,
-             step=0.05,
-             label="Top-p (nucleus sampling)",
-         ),
-     ],
  )
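Note on `find_key` as committed: a miss inside a nested dictionary returns the string "Indicator not found" rather than None, so the `result is not None` check fires on the first nested branch and any sibling keys that follow it are never examined. A minimal corrected sketch (the nested ESRS dictionary below is made up for illustration):

# Sketch only: a find_key variant that keeps scanning sibling keys.
# Returning None on a miss lets the outer loop continue instead of
# stopping at the first nested branch it recurses into.
def find_key(data, target_key):
    if isinstance(data, dict):
        for key, value in data.items():
            if key == target_key:
                return value
            result = find_key(value, target_key)
            if result is not None:
                return result
    return None  # caller decides how to report a missing indicator

# Hypothetical nested dictionary: "E1-5" sits after a nested branch.
dico = {"ESRS E1": {"E1-1": "Transition plan for climate change mitigation"},
        "E1-5": "Energy consumption and mix"}
print(find_key(dico, "E1-5") or "Indicator not found")  # -> Energy consumption and mix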
 
.ipynb_checkpoints/archive-checkpoint.py ADDED
@@ -0,0 +1,64 @@
+ import gradio as gr
+ from huggingface_hub import InferenceClient
+
+ """
+ For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
+ """
+ client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
+
+
+ def respond(
+     message,
+     history: list[tuple[str, str]],
+     system_message,
+     max_tokens,
+     temperature,
+     top_p,
+ ):
+     messages = [{"role": "system", "content": system_message}]
+
+     for val in history:
+         if val[0]:
+             messages.append({"role": "user", "content": val[0]})
+         if val[1]:
+             messages.append({"role": "assistant", "content": val[1]})
+
+     messages.append({"role": "user", "content": message})
+
+     response = ""
+
+     for message in client.chat_completion(
+         messages,
+         max_tokens=max_tokens,
+         stream=True,
+         temperature=temperature,
+         top_p=top_p,
+     ):
+         token = message.choices[0].delta.content
+
+         response += token
+         yield response
+
+
+ """
+ For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
+ """
+ demo = gr.ChatInterface(
+     respond,
+     additional_inputs=[
+         gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
+         gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
+         gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
+         gr.Slider(
+             minimum=0.1,
+             maximum=1.0,
+             value=0.95,
+             step=0.05,
+             label="Top-p (nucleus sampling)",
+         ),
+     ],
+ )
+
+
+ if __name__ == "__main__":
+     demo.launch()
.ipynb_checkpoints/requirements-checkpoint.txt ADDED
@@ -0,0 +1,2 @@
+ huggingface_hub==0.25.2
+ chromadb==0.6.3
.ipynb_checkpoints/test-checkpoint.txt DELETED
@@ -1 +0,0 @@
- blablabla
app.py CHANGED
@@ -1,43 +1,95 @@
  import gradio as gr
- from huggingface_hub import InferenceClient
-
- """
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
- """
- client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
-
-
- def respond(
-     message,
-     history: list[tuple[str, str]],
-     system_message,
-     max_tokens,
-     temperature,
-     top_p,
- ):
-     messages = [{"role": "system", "content": system_message}]
-
-     for val in history:
-         if val[0]:
-             messages.append({"role": "user", "content": val[0]})
-         if val[1]:
-             messages.append({"role": "assistant", "content": val[1]})
-
-     messages.append({"role": "user", "content": message})
-
-     response = ""
-
-     for message in client.chat_completion(
-         messages,
-         max_tokens=max_tokens,
-         stream=True,
-         temperature=temperature,
-         top_p=top_p,
-     ):
-         token = message.choices[0].delta.content
-
-         response += token
-         yield response
-
-
+ from huggingface_hub import HfApiModel
+ import sys
+ if './lib' not in sys.path :
+     sys.path.append('./lib')
+ from ingestion_chroma import retrieve_info_from_db
+
+ ############################################################################################
+ ################################### TOOLS ##################################################
+ ############################################################################################
+
+ def find_key(data, target_key):
+     if isinstance(data, dict):
+         for key, value in data.items():
+             if key == target_key:
+                 return value
+             else:
+                 result = find_key(value, target_key)
+                 if result is not None:
+                     return result
+     return "Indicator not found"
+
+ ############################################################################################
+
+ class Chroma_retrieverTool(Tool):
+     name = "request"
+     description = "Using semantic similarity, retrieve the text from the knowledge base that has the embedding closest to the query."
+     inputs = {
+         "query": {
+             "type": "string",
+             "description": "The query to execute must be semantically close to the text to search. Use the affirmative form rather than a question.",
+         },
+     }
+     output_type = "string"
+
+     def forward(self, query: str) -> str:
+         assert isinstance(query, str), "The request needs to be a string."
+
+         query_results = retrieve_info_from_db(query)
+         str_result = "\nRetrieval texts : \n" + "".join([f"===== Text {str(i)} =====\n" + query_results['documents'][0][i] for i in range(len(query_results['documents'][0]))])
+
+         return str_result
+
+ ############################################################################################
+
+ class ESRS_info_tool(Tool):
+     name = "find_ESRS"
+     description = "Find ESRS description to help you to find what indicators the user want"
+     inputs = {
+         "indicator": {
+             "type": "string",
+             "description": "The indicator name. return the description of the indicator demanded.",
+         },
+     }
+     output_type = "string"
+
+     def forward(self, indicator: str) -> str:
+         assert isinstance(indicator, str), "The request needs to be a string."
+
+         with open('./data/dico_esrs.json') as json_data:
+             dico_esrs = json.load(json_data)
+
+         result = find_key(dico_esrs, indicator)
+
+         return result
+
+ ############################################################################################
+ ############################################################################################
+ ############################################################################################
+
+ model = HfApiModel("Qwen/Qwen2.5-Coder-32B-Instruct")
+
+ retriever_tool = Chroma_retrieverTool()
+ get_ESRS_info_tool = ESRS_info_tool()
+ agent = CodeAgent(
+     tools=[
+         get_ESRS_info_tool,
+         retriever_tool,
+     ],
+     model=model,
+     max_steps=10,
+     max_print_outputs_length=16000,
+     additional_authorized_imports=['pandas', 'matplotlib', 'datetime']
+ )
+
+ def respond(message):
+     system_prompt_added = """You are an expert in environmental and corporate social responsibility. You must respond to requests using the query function in the document database.
+     User's question : """
+     agent_output = agent.run(system_prompt_added+"""Find all informations about the ESRS E1–5: Energy consumption from fossil sources in Sartorius documents.""")
+
+     yield agent_output
+
  """
@@ -45,18 +97,6 @@ For information on how to customize the ChatInterface, peruse the gradio docs: h
  """
  demo = gr.ChatInterface(
      respond,
-     additional_inputs=[
-         gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-         gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-         gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-         gr.Slider(
-             minimum=0.1,
-             maximum=1.0,
-             value=0.95,
-             step=0.05,
-             label="Top-p (nucleus sampling)",
-         ),
-     ],
  )
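Note on the new app.py: it references Tool, CodeAgent, and json without importing them, and `HfApiModel` is the model wrapper provided by `smolagents` rather than an export of `huggingface_hub`, so that import would likely fail as written. In addition, `respond(message)` ignores the incoming chat message (and the history argument that `gr.ChatInterface` passes) in favour of a hard-coded Sartorius query. A hedged sketch of the likely missing pieces, assuming `smolagents` is the intended source:

# Sketch only, assuming the agent classes come from smolagents rather than
# huggingface_hub; json is needed by ESRS_info_tool.forward.
import json
from smolagents import CodeAgent, HfApiModel, Tool

# Hypothetical respond() that accepts the (message, history) pair passed by
# gr.ChatInterface and forwards the user's actual question to the agent.
def respond(message, history):
    system_prompt_added = (
        "You are an expert in environmental and corporate social responsibility. "
        "You must respond to requests using the query function in the document database.\n"
        "User's question: "
    )
    yield agent.run(system_prompt_added + message)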
 
archive.py CHANGED
@@ -0,0 +1,64 @@
+ import gradio as gr
+ from huggingface_hub import InferenceClient
+
+ """
+ For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
+ """
+ client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
+
+
+ def respond(
+     message,
+     history: list[tuple[str, str]],
+     system_message,
+     max_tokens,
+     temperature,
+     top_p,
+ ):
+     messages = [{"role": "system", "content": system_message}]
+
+     for val in history:
+         if val[0]:
+             messages.append({"role": "user", "content": val[0]})
+         if val[1]:
+             messages.append({"role": "assistant", "content": val[1]})
+
+     messages.append({"role": "user", "content": message})
+
+     response = ""
+
+     for message in client.chat_completion(
+         messages,
+         max_tokens=max_tokens,
+         stream=True,
+         temperature=temperature,
+         top_p=top_p,
+     ):
+         token = message.choices[0].delta.content
+
+         response += token
+         yield response
+
+
+ """
+ For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
+ """
+ demo = gr.ChatInterface(
+     respond,
+     additional_inputs=[
+         gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
+         gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
+         gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
+         gr.Slider(
+             minimum=0.1,
+             maximum=1.0,
+             value=0.95,
+             step=0.05,
+             label="Top-p (nucleus sampling)",
+         ),
+     ],
+ )
+
+
+ if __name__ == "__main__":
+     demo.launch()
data/dico_esrs.json ADDED
The diff for this file is too large to render. See raw diff
 
lib/.ipynb_checkpoints/ingestion_chroma-checkpoint.py ADDED
@@ -0,0 +1,121 @@
+ import chromadb
+ from chromadb.utils import embedding_functions
+ from tqdm import tqdm
+ import time
+
+
+ ####################################################################################################################################
+ ############################################# GLOBAL INGESTION #####################################################################
+ ####################################################################################################################################
+ def prepare_chunks_for_ingestion(df):
+     """
+     Specialized for the RSE (CSR) report files.
+     """
+     chunks = list(df.full_chunk)
+     metadatas = [
+         {
+             "source": str(source),
+             "chunk_size": str(chunk_size),
+         }
+         for source, chunk_size in zip(list(df.source), list(df.chunk_size))
+     ]
+     return chunks, metadatas
+
+
+ ###################################################################################################################################
+ def ingest_chunks(df=None, batch_size=100, create_collection=False, chroma_data_path="./chroma_data/", embedding_model="intfloat/multilingual-e5-large", collection_name=None):
+     """
+     Adds to a RAG database from a dataframe whose metadata and text have already been read, and returns the question-answering pipeline.
+     Documents are already chunked!
+     Custom file slicing from self-care data.
+     Parameters:
+     - df: the dataframe of chunked docs with their metadata and text
+     - batch_size (optional)
+     Returns:
+     - collection: the resulting chroma collection
+     - durations: the list of durations of batch ingestion
+     """
+
+     print("Chosen embedding model: ", embedding_model)
+     print("Collection to ingest into: ", collection_name)
+     # The vector store collection is expected to already exist.
+     client = chromadb.PersistentClient(path=chroma_data_path)
+     embedding_func = embedding_functions.SentenceTransformerEmbeddingFunction(
+         model_name=embedding_model)
+
+     if create_collection:
+         collection = client.create_collection(
+             name=collection_name,
+             embedding_function=embedding_func,
+             metadata={"hnsw:space": "cosine"},
+         )
+         next_id = 0
+     else:
+         collection = client.get_collection(name=collection_name, embedding_function=embedding_func)
+         print("Computing next chroma id. Please wait a few minutes...")
+         next_id = compute_next_id_chroma(chroma_data_path, collection_name)
+     print("Preparing chunk metadata:")
+     documents, metadatas = prepare_chunks_for_ingestion(df)
+     # batch adding to do it faster
+     durations = []
+     total_batches = len(documents)/batch_size
+     initialisation = True
+     for i in tqdm(range(0, len(documents), batch_size)):
+         # print(f"Processing batch number {i/batch_size} of {total_batches}...")
+         if initialisation:
+             print(f"Processing first batch of {total_batches}.")
+             print("This can take 10-15 mins if this is the first time the model is loaded. Please wait...")
+             initialisation = False
+         with open("ingesting.log", "a") as file:
+             file.write(f"Processing batch number {i/batch_size} of {total_batches}..." + "\n")
+         batch_documents = documents[i:i+batch_size]
+         batch_ids = [f"id{j}" for j in range(next_id+i, next_id+i+len(batch_documents))]
+         batch_metadatas = metadatas[i:i+batch_size]
+         start_time = time.time()  # start measuring execution time
+         collection.add(
+             documents=batch_documents,
+             ids=batch_ids,  # [f"id{i}" for i in range(len(documents))],
+             metadatas=batch_metadatas
+         )
+         end_time = time.time()  # end measuring execution time
+         with open("ingesting.log", "a") as file:
+             file.write(f"Done. Collection adding time: {end_time-start_time}" + "\n")
+         durations.append(end_time-start_time)  # store execution times per batch
+     return collection, durations
+
+
+ ###################################################################################################################################
+ def clean_rag_collection(collname, chroma_data_path):
+     """Removes the old collection so the RAG can ingest the data anew."""
+     client = chromadb.PersistentClient(path=chroma_data_path)
+     res = client.delete_collection(name=collname)
+     return res
+
+
+ ###################################################################################################################################
+ def retrieve_info_from_db(prompt: str, entreprise=None):
+     EMBED_MODEL = 'intfloat/multilingual-e5-large'
+     collection_name = "RSE_CSRD_REPORTS_TEST"
+     # create the client
+     client = chromadb.PersistentClient(path="./data/chroma_data/")
+     # load the embedding model used to compute semantic similarity
+     embedding_func = embedding_functions.SentenceTransformerEmbeddingFunction(
+         model_name=EMBED_MODEL
+     )
+     collection = client.get_collection(name=collection_name, embedding_function=embedding_func)
+     if entreprise is not None:
+         # query
+         query_results = collection.query(
+             query_texts=[prompt],
+             n_results=3,
+             where={'source': entreprise}
+         )
+     else:
+         # query
+         query_results = collection.query(
+             query_texts=[prompt],
+             n_results=3
+         )
+
+     return query_results
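`ingest_chunks` expects a DataFrame carrying `full_chunk`, `source`, and `chunk_size` columns (see `prepare_chunks_for_ingestion`), and its `create_collection=False` branch calls `compute_next_id_chroma`, which is not defined in this module. A hedged usage sketch with a made-up one-row DataFrame, taking the `create_collection=True` path; the path and collection name mirror `retrieve_info_from_db` but are otherwise assumptions:

# Sketch only: the DataFrame columns and call shape ingest_chunks expects.
import pandas as pd
from ingestion_chroma import ingest_chunks

df = pd.DataFrame({
    "full_chunk": ["Example chunk of text extracted from a CSRD report."],
    "source": ["example_report.pdf"],  # illustrative file name
    "chunk_size": [52],
})

collection, durations = ingest_chunks(
    df=df,
    batch_size=100,
    create_collection=True,  # the False branch needs the undefined compute_next_id_chroma
    chroma_data_path="./data/chroma_data/",
    collection_name="RSE_CSRD_REPORTS_TEST",
)
print(f"Ingested {collection.count()} chunks; batch timings: {durations}")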
lib/ingestion_chroma.py ADDED
@@ -0,0 +1,121 @@
+ import chromadb
+ from chromadb.utils import embedding_functions
+ from tqdm import tqdm
+ import time
+
+
+ ####################################################################################################################################
+ ############################################# GLOBAL INGESTION #####################################################################
+ ####################################################################################################################################
+ def prepare_chunks_for_ingestion(df):
+     """
+     Specialized for the RSE (CSR) report files.
+     """
+     chunks = list(df.full_chunk)
+     metadatas = [
+         {
+             "source": str(source),
+             "chunk_size": str(chunk_size),
+         }
+         for source, chunk_size in zip(list(df.source), list(df.chunk_size))
+     ]
+     return chunks, metadatas
+
+
+ ###################################################################################################################################
+ def ingest_chunks(df=None, batch_size=100, create_collection=False, chroma_data_path="./chroma_data/", embedding_model="intfloat/multilingual-e5-large", collection_name=None):
+     """
+     Adds to a RAG database from a dataframe whose metadata and text have already been read, and returns the question-answering pipeline.
+     Documents are already chunked!
+     Custom file slicing from self-care data.
+     Parameters:
+     - df: the dataframe of chunked docs with their metadata and text
+     - batch_size (optional)
+     Returns:
+     - collection: the resulting chroma collection
+     - durations: the list of durations of batch ingestion
+     """
+
+     print("Chosen embedding model: ", embedding_model)
+     print("Collection to ingest into: ", collection_name)
+     # The vector store collection is expected to already exist.
+     client = chromadb.PersistentClient(path=chroma_data_path)
+     embedding_func = embedding_functions.SentenceTransformerEmbeddingFunction(
+         model_name=embedding_model)
+
+     if create_collection:
+         collection = client.create_collection(
+             name=collection_name,
+             embedding_function=embedding_func,
+             metadata={"hnsw:space": "cosine"},
+         )
+         next_id = 0
+     else:
+         collection = client.get_collection(name=collection_name, embedding_function=embedding_func)
+         print("Computing next chroma id. Please wait a few minutes...")
+         next_id = compute_next_id_chroma(chroma_data_path, collection_name)
+     print("Preparing chunk metadata:")
+     documents, metadatas = prepare_chunks_for_ingestion(df)
+     # batch adding to do it faster
+     durations = []
+     total_batches = len(documents)/batch_size
+     initialisation = True
+     for i in tqdm(range(0, len(documents), batch_size)):
+         # print(f"Processing batch number {i/batch_size} of {total_batches}...")
+         if initialisation:
+             print(f"Processing first batch of {total_batches}.")
+             print("This can take 10-15 mins if this is the first time the model is loaded. Please wait...")
+             initialisation = False
+         with open("ingesting.log", "a") as file:
+             file.write(f"Processing batch number {i/batch_size} of {total_batches}..." + "\n")
+         batch_documents = documents[i:i+batch_size]
+         batch_ids = [f"id{j}" for j in range(next_id+i, next_id+i+len(batch_documents))]
+         batch_metadatas = metadatas[i:i+batch_size]
+         start_time = time.time()  # start measuring execution time
+         collection.add(
+             documents=batch_documents,
+             ids=batch_ids,  # [f"id{i}" for i in range(len(documents))],
+             metadatas=batch_metadatas
+         )
+         end_time = time.time()  # end measuring execution time
+         with open("ingesting.log", "a") as file:
+             file.write(f"Done. Collection adding time: {end_time-start_time}" + "\n")
+         durations.append(end_time-start_time)  # store execution times per batch
+     return collection, durations
+
+
+ ###################################################################################################################################
+ def clean_rag_collection(collname, chroma_data_path):
+     """Removes the old collection so the RAG can ingest the data anew."""
+     client = chromadb.PersistentClient(path=chroma_data_path)
+     res = client.delete_collection(name=collname)
+     return res
+
+
+ ###################################################################################################################################
+ def retrieve_info_from_db(prompt: str, entreprise=None):
+     EMBED_MODEL = 'intfloat/multilingual-e5-large'
+     collection_name = "RSE_CSRD_REPORTS_TEST"
+     # create the client
+     client = chromadb.PersistentClient(path="./data/chroma_data/")
+     # load the embedding model used to compute semantic similarity
+     embedding_func = embedding_functions.SentenceTransformerEmbeddingFunction(
+         model_name=EMBED_MODEL
+     )
+     collection = client.get_collection(name=collection_name, embedding_function=embedding_func)
+     if entreprise is not None:
+         # query
+         query_results = collection.query(
+             query_texts=[prompt],
+             n_results=3,
+             where={'source': entreprise}
+         )
+     else:
+         # query
+         query_results = collection.query(
+             query_texts=[prompt],
+             n_results=3
+         )
+
+     return query_results
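`retrieve_info_from_db` returns the raw Chroma `query()` result, a dict of per-query lists; that is why `Chroma_retrieverTool.forward` in app.py indexes `query_results['documents'][0][i]`. A short sketch of reading the same structure directly (the query text is illustrative):

# Sketch only: the shape of the chromadb query() result consumed by the app.
from ingestion_chroma import retrieve_info_from_db

results = retrieve_info_from_db("energy consumption from fossil sources")
docs = results["documents"][0]   # up to n_results=3 chunk texts for the single query
metas = results["metadatas"][0]  # matching {"source": ..., "chunk_size": ...} metadata
dists = results["distances"][0]  # distances reported by Chroma for each match

for text, meta, dist in zip(docs, metas, dists):
    print(f"{meta['source']} (distance {dist:.3f}): {text[:80]}")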
requirements.txt CHANGED
@@ -1 +1,2 @@
- huggingface_hub==0.25.2
+ huggingface_hub==0.25.2
+ chromadb==0.6.3
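The updated code also imports `gradio`, the agent classes discussed above, `pandas`, and `tqdm`, and Chroma's `SentenceTransformerEmbeddingFunction` pulls in `sentence-transformers`; none of these are pinned here. A hedged sketch of a fuller requirements file, where every unpinned entry is an assumption rather than part of the commit:

# Sketch only: possible additions beyond the two committed pins.
huggingface_hub==0.25.2
chromadb==0.6.3
gradio
smolagents
sentence-transformers
pandas
tqdm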
test.txt DELETED
@@ -1 +0,0 @@
- blablabla