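"""Ingest knowledge-card chunks into a local Qdrant collection.

Downloads a parquet dataset from the Hugging Face Hub, embeds each chunk with a
dense encoder (mxbai-embed-large-v1) and a sparse BM25 encoder, and uploads both
vectors per point so the collection supports hybrid search.
"""
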
import os
from pathlib import Path
from typing import Any

import pandas as pd
from fastembed import SparseTextEmbedding, SparseEmbedding
from sentence_transformers import SentenceTransformer
from huggingface_hub import hf_hub_download
from qdrant_client import QdrantClient
from qdrant_client import models as qmodels

# Configuration from environment variables
VLLM_DTYPE = os.getenv("VLLM_DTYPE")
DATA_PATH = Path(os.getenv("DATA_PATH"))
DB_PATH = DATA_PATH / "db"
HF_TOKEN = os.getenv("HF_TOKEN")
RECREATE_DB = os.getenv("RECREATE_DB", "False").lower() == "true"
DATA_REPO = os.getenv("DATA_REPO")
DATA_FILENAME = os.getenv("DATA_FILENAME")
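
# Embedded (on-disk) Qdrant instance plus the dense and sparse encoders used for hybrid search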
client = QdrantClient(path=str(DB_PATH))
collection_name = "knowledge_cards"
dense_model_dims = 1024
dense_batch_size = 128
sparse_batch_size = 256

dense_encoder = SentenceTransformer(
    model_name_or_path="mixedbread-ai/mxbai-embed-large-v1",
    device="cuda",
    model_kwargs={"torch_dtype": VLLM_DTYPE},
)
sparse_encoder = SparseTextEmbedding(model_name="Qdrant/bm25", cuda=True)

# Utils
def convert_serialized_sparse_embeddings(sparse_dict: dict[str, float]) -> SparseEmbedding:
    """Rebuild a SparseEmbedding from a serialized dict whose keys were stringified
    (e.g. for PyArrow compatibility) by converting them back to integers."""
    return SparseEmbedding.from_dict({int(k): v for k, v in sparse_dict.items()})

def ingest_data(chunks: list[dict[str, Any]]):
    """Create the hybrid (dense + sparse) collection and upload one point per chunk."""
    if client.collection_exists(collection_name) and RECREATE_DB:
        print("Recreating collection.", flush=True)
        client.delete_collection(collection_name)
    elif client.collection_exists(collection_name):
        print("Collection already exists, skipping ingestion.", flush=True)
        return

    print("Ingesting knowledge cards...", flush=True)
    client.create_collection(
        collection_name=collection_name,
        vectors_config={
            "dense": qmodels.VectorParams(
                size=dense_model_dims,
                distance=qmodels.Distance.COSINE,
            )
        },
        sparse_vectors_config={
            "sparse": qmodels.SparseVectorParams(modifier=qmodels.Modifier.IDF)
        },
    )
    # Generate embeddings
    chunk_texts = [chunk["text"] for chunk in chunks]
    dense_vectors = list(
        dense_encoder.encode(
            chunk_texts,
            batch_size=dense_batch_size,
            normalize_embeddings=True,
        )
    )
    sparse_vectors = list(
        sparse_encoder.embed(chunk_texts, batch_size=sparse_batch_size)
    )
    # Upload to db
    client.upload_points(
        collection_name=collection_name,
        points=[
            qmodels.PointStruct(
                id=idx,
                payload=chunk,
                vector={"dense": dense, "sparse": sparse.as_object()},
            )
            for idx, (chunk, dense, sparse) in enumerate(
                zip(chunks, dense_vectors, sparse_vectors)
            )
        ],
    )

def ingest():
    """Download the knowledge-card parquet from the Hub and ingest it into Qdrant."""
    downloaded_path = hf_hub_download(
        repo_id=DATA_REPO, filename=DATA_FILENAME, token=HF_TOKEN, repo_type="dataset"
    )
    print(f"Downloaded knowledge card dataset; path = {downloaded_path}", flush=True)
    chunk_df = pd.read_parquet(downloaded_path)
    chunks = chunk_df.to_dict(orient="records")
    ingest_data(chunks=chunks)
    print("Ingestion is finished.", flush=True)

if __name__ == "__main__":
    ingest()