import os
from pathlib import Path
from typing import Any

import pandas as pd
from fastembed import SparseTextEmbedding, SparseEmbedding
from sentence_transformers import SentenceTransformer
from huggingface_hub import hf_hub_download
from qdrant_client import QdrantClient
from qdrant_client import models as qmodels

VLLM_DTYPE = os.getenv("VLLM_DTYPE")
DATA_PATH = Path(os.getenv("DATA_PATH"))
DB_PATH = DATA_PATH / "db"
HF_TOKEN = os.getenv("HF_TOKEN")
RECREATE_DB = os.getenv("RECREATE_DB", "False").lower() == "true"
DATA_REPO = os.getenv("DATA_REPO")
DATA_FILENAME = os.getenv("DATA_FILENAME")

client = QdrantClient(path=str(DB_PATH))
collection_name = "knowledge_cards"
dense_model_dims = 1024
dense_batch_size = 128
sparse_batch_size = 256

# Dense encoder for semantic similarity; sparse BM25 encoder for lexical matching.
dense_encoder = SentenceTransformer(
    model_name_or_path="mixedbread-ai/mxbai-embed-large-v1",
    device="cuda",
    model_kwargs={"torch_dtype": VLLM_DTYPE},
)
sparse_encoder = SparseTextEmbedding(model_name="Qdrant/bm25", cuda=True)


# Utils
def convert_serialized_sparse_embeddings(sparse_dict: dict[str, float]) -> SparseEmbedding:
    """Rebuild a SparseEmbedding from a dict whose keys were stringified for PyArrow compatibility."""
    return SparseEmbedding.from_dict({int(k): v for k, v in sparse_dict.items()})


def ingest_data(chunks: list[dict[str, Any]]):
    if client.collection_exists(collection_name) and RECREATE_DB:
        print("Recreating collection.", flush=True)
        client.delete_collection(collection_name)
    elif client.collection_exists(collection_name):
        print("Collection already exists, skipping ingestion.", flush=True)
        return

    print("Ingesting knowledge cards...", flush=True)
    client.create_collection(
        collection_name=collection_name,
        vectors_config={
            "dense": qmodels.VectorParams(
                size=dense_model_dims,
                distance=qmodels.Distance.COSINE,
            )
        },
        sparse_vectors_config={
            "sparse": qmodels.SparseVectorParams(modifier=qmodels.Modifier.IDF)
        },
    )

    # Generate embeddings
    chunk_texts = [chunk["text"] for chunk in chunks]
    dense_vectors = list(
        dense_encoder.encode(
            chunk_texts,
            batch_size=dense_batch_size,
            normalize_embeddings=True,
        )
    )
    sparse_vectors = list(
        sparse_encoder.embed(chunk_texts, batch_size=sparse_batch_size)
    )

    # Upload to db
    client.upload_points(
        collection_name=collection_name,
        points=[
            qmodels.PointStruct(
                id=idx,
                payload=chunk,
                vector={"dense": dense, "sparse": sparse.as_object()},
            )
            for idx, (chunk, dense, sparse) in enumerate(
                zip(chunks, dense_vectors, sparse_vectors)
            )
        ],
    )


def ingest():
    downloaded_path = hf_hub_download(
        repo_id=DATA_REPO, filename=DATA_FILENAME, token=HF_TOKEN, repo_type="dataset"
    )
    print(f"Downloaded knowledge card dataset; path = {downloaded_path}", flush=True)
    chunk_df = pd.read_parquet(downloaded_path)
    chunks = chunk_df.to_dict(orient="records")
    ingest_data(chunks=chunks)
    print("Ingestion is finished.", flush=True)


if __name__ == "__main__":
    ingest()