# CSRD_reports_analysis/lib/ingestion_chroma.py
# NOTE: removed Hugging Face scrape residue (author Matteo-CNPPS, commit 2ce0b48 "test_commit")
# that was not valid Python and would have broken the module at import time.
import chromadb
from chromadb.utils import embedding_functions
from tqdm import tqdm
import time
####################################################################################################################################
############################################# GLOBAL INGESTION #####################################################################
####################################################################################################################################
def prepare_chunks_for_ingestion(df):
    """Extract chunk texts and per-chunk metadata from an RSE dataframe.

    Expects the dataframe to carry ``full_chunk``, ``source`` and
    ``chunk_size`` columns.

    Returns:
    - chunks: list of chunk texts, in dataframe order
    - metadatas: parallel list of dicts with stringified ``source`` and
      ``chunk_size``, ready for ``collection.add``
    """
    texts = [chunk for chunk in df.full_chunk]
    metadatas = []
    for src, size in zip(df.source, df.chunk_size):
        metadatas.append({"source": str(src), "chunk_size": str(size)})
    return texts, metadatas
###################################################################################################################################
def ingest_chunks(df=None, batch_size=100, create_collection=False, chroma_data_path="./chroma_data/", embedding_model="intfloat/multilingual-e5-large", collection_name=None):
    """
    Add pre-chunked documents from a dataframe to a persistent Chroma collection.

    Documents are already chunked: the dataframe must expose ``full_chunk``,
    ``source`` and ``chunk_size`` columns (see prepare_chunks_for_ingestion).

    Parameters:
    - df: dataframe of chunked docs with their metadata and text
    - batch_size: number of documents added per collection.add call (optional)
    - create_collection: when True, create the collection (cosine HNSW space);
      otherwise the collection must already exist and new ids continue after
      the current maximum
    - chroma_data_path: path of the persistent Chroma store
    - embedding_model: sentence-transformers model name used for embeddings
    - collection_name: name of the target collection

    Returns:
    - collection: the resulting chroma collection
    - durations: list of per-batch ingestion times in seconds
    """
    print("Modèle d'embedding choisi: ", embedding_model)
    print("Collection où ingérer: ", collection_name)
    client = chromadb.PersistentClient(path=chroma_data_path)
    embedding_func = embedding_functions.SentenceTransformerEmbeddingFunction(
        model_name=embedding_model)
    if create_collection:
        collection = client.create_collection(
            name=collection_name,
            embedding_function=embedding_func,
            metadata={"hnsw:space": "cosine"},
        )
        next_id = 0
    else:
        # The collection is expected to already exist; resume ids after the
        # current maximum so we never overwrite previously ingested chunks.
        collection = client.get_collection(name=collection_name, embedding_function=embedding_func)
        print("Computing next chroma id. Please wait a few minutes...")
        next_id = compute_next_id_chroma(chroma_data_path, collection_name)
    print("Préparation des métadatas des chunks :")
    documents, metadatas = prepare_chunks_for_ingestion(df)
    # Batch adding to do it faster. Ceil division so a final partial batch
    # still counts as a batch (the original float division mis-reported this).
    durations = []
    total_batches = (len(documents) + batch_size - 1) // batch_size
    first_batch = True
    # Open the log once instead of re-opening it for every batch.
    with open("ingesting.log", "a") as log:
        for i in tqdm(range(0, len(documents), batch_size)):
            if first_batch:
                print(f"Processing first batch of {total_batches}.")
                print("This can take 10-15 mins if this is the first time the model is loaded. Please wait...")
                first_batch = False
            log.write(f"Processing batch number {i // batch_size + 1} of {total_batches}..." + "\n")
            log.flush()  # keep the log readable while a long ingestion runs
            batch_documents = documents[i:i+batch_size]
            batch_ids = [f"id{j}" for j in range(next_id+i, next_id+i+len(batch_documents))]
            batch_metadatas = metadatas[i:i+batch_size]
            start_time = time.time()  # start measuring execution time
            collection.add(
                documents=batch_documents,
                ids=batch_ids,
                metadatas=batch_metadatas
            )
            end_time = time.time()  # end measuring execution time
            log.write(f"Done. Collection adding time: {end_time-start_time}" + "\n")
            log.flush()
            durations.append(end_time - start_time)  # execution time per batch
    return collection, durations
###################################################################################################################################
def clean_rag_collection(collname, chroma_data_path):
    """Delete an existing Chroma collection so the RAG data can be re-ingested.

    Parameters:
    - collname: name of the collection to drop
    - chroma_data_path: path of the persistent Chroma store
    """
    chroma_client = chromadb.PersistentClient(path=chroma_data_path)
    return chroma_client.delete_collection(name=collname)
###################################################################################################################################
def retrieve_info_from_db(prompt: str, entreprise=None):
    """Query the CSRD/RSE Chroma collection for the 3 chunks closest to the prompt.

    Parameters:
    - prompt: query text used for semantic similarity search
    - entreprise: optional company identifier; when given, results are
      restricted to chunks whose ``source`` metadata equals it

    Returns:
    - the raw Chroma query result dict (ids, documents, metadatas, distances)
    """
    EMBED_MODEL = 'intfloat/multilingual-e5-large'
    collection_name = "RSE_CSRD_REPORTS_TEST"
    # Client on the persistent store used at ingestion time.
    client = chromadb.PersistentClient(path="./data/chroma_data/")
    # Embedding model computing semantic proximity; must match ingestion.
    embedding_func = embedding_functions.SentenceTransformerEmbeddingFunction(
        model_name=EMBED_MODEL
    )
    collection = client.get_collection(name=collection_name, embedding_function=embedding_func)
    # Build the optional metadata filter once instead of duplicating the
    # whole query call in both branches (the original repeated it verbatim).
    query_kwargs = {"query_texts": [prompt], "n_results": 3}
    if entreprise is not None:
        query_kwargs["where"] = {'source': entreprise}
    return collection.query(**query_kwargs)