# sdmrec-docker / ingest.py
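"""Download the knowledge-card dataset from the Hugging Face Hub, embed each chunk with a
dense (mxbai-embed-large-v1) and a sparse (BM25) encoder, and ingest the results into a
local on-disk Qdrant collection for hybrid retrieval."""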
import os
from pathlib import Path
from typing import Any
import pandas as pd
from fastembed import SparseTextEmbedding, SparseEmbedding
from sentence_transformers import SentenceTransformer
from huggingface_hub import hf_hub_download
from qdrant_client import QdrantClient
from qdrant_client import models as qmodels
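
# Runtime configuration, injected via environment variables (e.g. from the Docker environment).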
VLLM_DTYPE = os.getenv("VLLM_DTYPE")
# Required settings are read with os.environ so a missing variable fails fast with a clear KeyError.
DATA_PATH = Path(os.environ["DATA_PATH"])
DB_PATH = DATA_PATH / "db"
HF_TOKEN = os.getenv("HF_TOKEN")
RECREATE_DB = os.getenv("RECREATE_DB", "False").lower() == "true"
DATA_REPO = os.environ["DATA_REPO"]
DATA_FILENAME = os.environ["DATA_FILENAME"]
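
# Embedded (on-disk) Qdrant instance; data is persisted under DATA_PATH/db.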
client = QdrantClient(path=str(DB_PATH))
collection_name = "knowledge_cards"
dense_model_dims = 1024
dense_batch_size = 128
sparse_batch_size = 256
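
# Dense encoder (1024-dim mxbai-embed-large-v1 on CUDA) and sparse BM25 encoder from fastembed.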
dense_encoder = SentenceTransformer(
model_name_or_path="mixedbread-ai/mxbai-embed-large-v1",
device="cuda",
model_kwargs={"torch_dtype": VLLM_DTYPE},
)
sparse_encoder = SparseTextEmbedding(model_name="Qdrant/bm25", cuda=True)
# Utils
def convert_serialized_sparse_embeddings(sparse_dict: dict[str, float]) -> SparseEmbedding:
    """Rebuild a SparseEmbedding from a dict whose integer indices were stringified for PyArrow compatibility."""
    return SparseEmbedding.from_dict({int(k): v for k, v in sparse_dict.items()})
def ingest_data(chunks: list[dict[str, Any]]):
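    """Create the hybrid (dense + sparse) collection if needed and upload one point per chunk.

    Each chunk dict is stored as the point payload; its "text" field is embedded with both encoders.
    """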
if client.collection_exists(collection_name) and RECREATE_DB:
print("Recreating collection.", flush=True)
client.delete_collection(collection_name)
elif client.collection_exists(collection_name):
print("Collection already exists, skipping ingestion.", flush=True)
return
print("Ingesting knowledge cards...", flush=True)
client.create_collection(
collection_name=collection_name,
vectors_config={
"dense": qmodels.VectorParams(
size=dense_model_dims,
distance=qmodels.Distance.COSINE,
)
},
sparse_vectors_config={
"sparse": qmodels.SparseVectorParams(modifier=qmodels.Modifier.IDF)
},
)
# Generate embeddings
chunk_texts = [chunk["text"] for chunk in chunks]
dense_vectors = list(
dense_encoder.encode(
chunk_texts,
batch_size=dense_batch_size,
normalize_embeddings=True,
)
)
sparse_vectors = list(
sparse_encoder.embed(chunk_texts, batch_size=sparse_batch_size)
)
    # Upload to the collection; point IDs are sequential chunk indices, the full chunk dict is the payload.
client.upload_points(
collection_name=collection_name,
points=[
qmodels.PointStruct(
id=idx,
payload=chunk,
vector={"dense": dense, "sparse": sparse.as_object()},
)
for idx, (chunk, dense, sparse) in enumerate(
zip(chunks, dense_vectors, sparse_vectors)
)
],
)
def ingest():
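    """Download the knowledge-card parquet from the Hugging Face Hub and ingest its rows."""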
downloaded_path = hf_hub_download(
repo_id=DATA_REPO, filename=DATA_FILENAME, token=HF_TOKEN, repo_type="dataset"
)
print(f"Downloaded knowledge card dataset; path = {downloaded_path}", flush=True)
chunk_df = pd.read_parquet(downloaded_path)
chunks = chunk_df.to_dict(orient="records")
ingest_data(chunks=chunks)
print("Ingestion is finished.", flush=True)
if __name__ == "__main__":
ingest()