import chromadb
from chromadb.utils import embedding_functions
from tqdm import tqdm
import time
####################################################################################################################################
############################################# GLOBAL INGESTION #####################################################################
####################################################################################################################################
def prepare_chunks_for_ingestion(df):
    """
    Specialized for RSE (CSR) report files.
    Extracts the chunk texts and their per-chunk metadata from a pre-chunked dataframe.
    """
    chunks = list(df.full_chunk)
    # One metadata dict per chunk; values are cast to str since Chroma expects scalar metadata.
    metadatas = [
        {
            "source": str(source),
            "chunk_size": str(chunk_size),
        }
        for source, chunk_size in zip(list(df.source), list(df.chunk_size))
    ]
    return chunks, metadatas
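# Example (sketch): prepare_chunks_for_ingestion expects a dataframe with
# "full_chunk", "source" and "chunk_size" columns, e.g.:
#   import pandas as pd
#   df = pd.DataFrame({
#       "full_chunk": ["Text of one chunk..."],
#       "source": ["company_report.pdf"],  # hypothetical source file
#       "chunk_size": [512],
#   })
#   chunks, metadatas = prepare_chunks_for_ingestion(df)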
###################################################################################################################################
def ingest_chunks(df=None, batch_size=100, create_collection=False, chroma_data_path="./chroma_data/", embedding_model="intfloat/multilingual-e5-large", collection_name=None):
    """
    Ingests pre-chunked documents (text + metadata) from a dataframe into a Chroma collection.
    The documents must already be chunked: the custom slicing of the self-care data is done upstream.
    Parameters:
    - df: the dataframe of chunked docs with their metadata and text
    - batch_size (optional): number of chunks added per collection.add() call
    Returns:
    - collection: the resulting Chroma collection
    - durations: the list of per-batch ingestion times, in seconds
    """
print("Modèle d'embedding choisi: ", embedding_model) | |
print("Collection où ingérer: ", collection_name) | |
# La collection du vector store est censée déjà exister. | |
client = chromadb.PersistentClient(path=chroma_data_path) | |
embedding_func = embedding_functions.SentenceTransformerEmbeddingFunction( | |
model_name=embedding_model) | |
    if create_collection:
        collection = client.create_collection(
            name=collection_name,
            embedding_function=embedding_func,
            metadata={"hnsw:space": "cosine"},  # cosine distance for similarity search
        )
        next_id = 0
    else:
        collection = client.get_collection(name=collection_name, embedding_function=embedding_func)
        print("Computing next chroma id. Please wait a few minutes...")
        next_id = compute_next_id_chroma(chroma_data_path, collection_name)
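        # NOTE: compute_next_id_chroma is defined elsewhere in this project. A minimal
        # sketch of what it is assumed to do (continue the "id<N>" sequence by counting
        # the items already stored), shown for illustration only:
        #   def compute_next_id_chroma(chroma_data_path, collection_name):
        #       client = chromadb.PersistentClient(path=chroma_data_path)
        #       return client.get_collection(name=collection_name).count()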
print("Préparation des métadatas des chunks :") | |
documents, metadatas = prepare_chunks_for_ingestion(df) | |
# batch adding to do it faster | |
durations = [] | |
total_batches = len(documents)/batch_size | |
initialisation=True | |
    for i in tqdm(range(0, len(documents), batch_size)):
        if initialisation:
            print(f"Processing first batch of {total_batches}.")
            print("This can take 10-15 mins if this is the first time the model is loaded. Please wait...")
            initialisation = False
        with open("ingesting.log", "a") as file:
            file.write(f"Processing batch number {i // batch_size + 1} of {total_batches}..." + "\n")
        batch_documents = documents[i:i+batch_size]
        batch_ids = [f"id{j}" for j in range(next_id+i, next_id+i+len(batch_documents))]
        batch_metadatas = metadatas[i:i+batch_size]
        start_time = time.time()  # start measuring execution time
        collection.add(
            documents=batch_documents,
            ids=batch_ids,
            metadatas=batch_metadatas
        )
        end_time = time.time()  # end measuring execution time
        with open("ingesting.log", "a") as file:
            file.write(f"Done. Collection adding time: {end_time-start_time}" + "\n")
        durations.append(end_time - start_time)  # store execution times per batch
    return collection, durations
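# Example usage (sketch; the dataframe and collection name are hypothetical):
#   collection, durations = ingest_chunks(
#       df=df,
#       batch_size=100,
#       create_collection=True,             # first run: create the collection
#       chroma_data_path="./chroma_data/",
#       collection_name="RSE_CSRD_REPORTS_TEST",
#   )
#   # Later runs with create_collection=False append after the last existing id.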
###################################################################################################################################
def clean_rag_collection(collname, chroma_data_path):
    """ Removes the old collection so the RAG can ingest fresh data.
    """
    client = chromadb.PersistentClient(path=chroma_data_path)
    res = client.delete_collection(name=collname)
    return res
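# Example usage (sketch): drop the collection before a full re-ingestion:
#   clean_rag_collection("RSE_CSRD_REPORTS_TEST", "./data/chroma_data/")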
###################################################################################################################################
def retrieve_info_from_db(prompt: str, entreprise=None):
    EMBED_MODEL = 'intfloat/multilingual-e5-large'
    collection_name = "RSE_CSRD_REPORTS_TEST"
    # create the client
    client = chromadb.PersistentClient(path="./data/chroma_data/")
    # load the embedding model used to compute semantic similarity
    embedding_func = embedding_functions.SentenceTransformerEmbeddingFunction(
        model_name=EMBED_MODEL
    )
    collection = client.get_collection(name=collection_name, embedding_function=embedding_func)
    # query the 3 closest chunks, optionally filtered on the company ("entreprise") source
    where_filter = {'source': entreprise} if entreprise is not None else None
    query_results = collection.query(
        query_texts=[prompt],
        n_results=3,
        where=where_filter
    )
    return query_results
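# Example usage (sketch; the "source" value is hypothetical):
#   results = retrieve_info_from_db(
#       "What are the company's CO2 emissions?",
#       entreprise="company_report.pdf",
#   )
#   print(results["documents"][0])  # texts of the top-3 chunks for the first query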