|
import chromadb |
|
from chromadb.utils import embedding_functions |
|
from tqdm import tqdm |
|
import time |
|
|
|
|
|
|
|
|
|
|
|
def prepare_chunks_for_ingestion(df):
    """Extract chunk texts and per-chunk metadata from a dataframe.

    Specialized for the RSE files: the dataframe is expected to have the
    columns ``full_chunk`` (the chunk text), ``source`` and ``chunk_size``.

    Parameters:
    - df: dataframe of pre-chunked documents.

    Returns:
    - chunks: list of chunk texts, in dataframe order.
    - metadatas: parallel list of dicts with stringified ``source`` and
      ``chunk_size`` values (chroma metadata must be strings here).
    """
    chunks = list(df.full_chunk)
    metadatas = []
    for source, chunk_size in zip(list(df.source), list(df.chunk_size)):
        metadatas.append({
            "source": str(source),
            "chunk_size": str(chunk_size),
        })
    return chunks, metadatas
|
|
|
|
|
|
|
def ingest_chunks(df=None, batch_size=100, create_collection=False, chroma_data_path="./chroma_data/", embedding_model="intfloat/multilingual-e5-large", collection_name=None):
    """Ingest pre-chunked documents from a dataframe into a chroma collection.

    Documents are already chunked; this only embeds and stores them,
    batch by batch, logging progress to ``ingesting.log``.

    Parameters:
    - df: dataframe of chunked docs with their metadata and text
      (columns ``full_chunk``, ``source``, ``chunk_size``).
    - batch_size: number of documents added per ``collection.add`` call.
    - create_collection: if True, create a fresh collection (ids start at 0);
      otherwise append to the existing one, continuing from the next free id.
    - chroma_data_path: directory of the persistent chroma store.
    - embedding_model: SentenceTransformer model name used for embeddings.
    - collection_name: name of the chroma collection to create/open.

    Returns:
    - collection: the resulting chroma collection.
    - durations: list of per-batch ingestion durations in seconds.
    """
    print("Modèle d'embedding choisi: ", embedding_model)
    print("Collection où ingérer: ", collection_name)

    client = chromadb.PersistentClient(path=chroma_data_path)
    embedding_func = embedding_functions.SentenceTransformerEmbeddingFunction(
        model_name=embedding_model)

    if create_collection:
        collection = client.create_collection(
            name=collection_name,
            embedding_function=embedding_func,
            metadata={"hnsw:space": "cosine"},
        )
        next_id = 0
    else:
        collection = client.get_collection(name=collection_name, embedding_function=embedding_func)
        print("Computing next chroma id. Please wait a few minutes...")
        # Sibling helper (defined elsewhere in this module); scans the store
        # so existing ids are not overwritten when appending.
        next_id = compute_next_id_chroma(chroma_data_path, collection_name)
    print("Préparation des métadatas des chunks :")
    documents, metadatas = prepare_chunks_for_ingestion(df)

    durations = []
    # Ceiling division: a partial final batch still counts as one batch
    # (the original float division reported e.g. "batch 0.0 of 3.5").
    total_batches = (len(documents) + batch_size - 1) // batch_size
    initialisation = True
    for i in tqdm(range(0, len(documents), batch_size)):
        if initialisation:
            print(f"Processing first batch of {total_batches}.")
            print("This can take 10-15 mins if this is the first time the model is loaded. Please wait...")
            initialisation = False
        with open("ingesting.log", "a") as file:
            # 1-based integer batch number, consistent with total_batches.
            file.write(f"Processing batch number {i // batch_size + 1} of {total_batches}..." + "\n")
        batch_documents = documents[i:i+batch_size]
        # Ids continue from next_id so appends never collide with stored ids.
        batch_ids = [f"id{j}" for j in range(next_id+i, next_id+i+len(batch_documents))]
        batch_metadatas = metadatas[i:i+batch_size]
        start_time = time.time()
        collection.add(
            documents=batch_documents,
            ids=batch_ids,
            metadatas=batch_metadatas
        )
        end_time = time.time()
        with open("ingesting.log", "a") as file:
            file.write(f"Done. Collection adding time: {end_time-start_time}"+"\n")
        durations.append(end_time-start_time)
    return collection, durations
|
|
|
|
|
|
|
def clean_rag_collection(collname, chroma_data_path):
    """Delete an existing collection so the RAG can ingest data afresh.

    Parameters:
    - collname: name of the chroma collection to delete.
    - chroma_data_path: directory of the persistent chroma store.

    Returns whatever ``delete_collection`` returns.
    """
    chroma_client = chromadb.PersistentClient(path=chroma_data_path)
    return chroma_client.delete_collection(name=collname)
|
|
|
|
|
|
|
def retrieve_info_from_db(prompt: str, entreprise=None):
    """Retrieve the 3 chunks most similar to *prompt* from the test collection.

    If *entreprise* is given, results are restricted to documents whose
    ``source`` metadata equals it.

    Returns the raw chroma query result dict.
    """
    EMBED_MODEL = 'intfloat/multilingual-e5-large'
    collection_name = "RSE_CSRD_REPORTS_TEST"

    chroma_client = chromadb.PersistentClient(path="./data/chroma_data/")

    embed_fn = embedding_functions.SentenceTransformerEmbeddingFunction(
        model_name=EMBED_MODEL
    )
    collection = chroma_client.get_collection(name=collection_name, embedding_function=embed_fn)

    # Build the query once; only add the metadata filter when requested.
    query_kwargs = {"query_texts": [prompt], "n_results": 3}
    if entreprise is not None:
        query_kwargs["where"] = {'source': entreprise}

    return collection.query(**query_kwargs)