# sdmrec-docker / ingest.py
import os
from pathlib import Path
from typing import Any

import pandas as pd
from fastembed import SparseTextEmbedding, SparseEmbedding
from sentence_transformers import SentenceTransformer
from huggingface_hub import hf_hub_download
from qdrant_client import QdrantClient
from qdrant_client import models as qmodels
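
# Configuration is read from environment variables (DATA_PATH must be set).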
VLLM_DTYPE = os.getenv("VLLM_DTYPE")
DATA_PATH = Path(os.getenv("DATA_PATH"))
DB_PATH = DATA_PATH / "db"
HF_TOKEN = os.getenv("HF_TOKEN")
RECREATE_DB = os.getenv("RECREATE_DB", "False").lower() == "true"
DATA_REPO = os.getenv("DATA_REPO")
DATA_FILENAME = os.getenv("DATA_FILENAME")
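
# Embedded (file-based) Qdrant instance persisted under DATA_PATH/db.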
client = QdrantClient(path=str(DB_PATH))
collection_name = "knowledge_cards"
dense_model_dims = 1024
dense_batch_size = 128
sparse_batch_size = 256
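
# Dense encoder: mxbai-embed-large-v1 (1024-dim) on GPU; sparse encoder: BM25 via fastembed.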
dense_encoder = SentenceTransformer(
    model_name_or_path="mixedbread-ai/mxbai-embed-large-v1",
    device="cuda",
    model_kwargs={"torch_dtype": VLLM_DTYPE},
)
sparse_encoder = SparseTextEmbedding(model_name="Qdrant/bm25", cuda=True)


# Utils
def convert_serialized_sparse_embeddings(sparse_dict: dict[str, float]):
    """Rebuild a SparseEmbedding from a dict whose integer keys were serialized as strings (e.g. by PyArrow)."""
    return SparseEmbedding.from_dict({int(k): v for k, v in sparse_dict.items()})
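

# Ingestion: create the hybrid collection (if needed) and upload embedded chunks.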
def ingest_data(chunks: list[dict[str, Any]]):
    if client.collection_exists(collection_name) and RECREATE_DB:
        print("Recreating collection.", flush=True)
        client.delete_collection(collection_name)
    elif client.collection_exists(collection_name):
        print("Collection already exists, skipping ingestion.", flush=True)
        return
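
    # Create a fresh collection with a named dense vector (cosine distance) and a named BM25 sparse vector using the IDF modifier.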
print("Ingesting knowledge cards...", flush=True)
client.create_collection(
collection_name=collection_name,
vectors_config={
"dense": qmodels.VectorParams(
size=dense_model_dims,
distance=qmodels.Distance.COSINE,
)
},
sparse_vectors_config={
"sparse": qmodels.SparseVectorParams(modifier=qmodels.Modifier.IDF)
},
)

    # Generate embeddings
    chunk_texts = [chunk["text"] for chunk in chunks]
    dense_vectors = list(
        dense_encoder.encode(
            chunk_texts,
            batch_size=dense_batch_size,
            normalize_embeddings=True,
        )
    )
    sparse_vectors = list(
        sparse_encoder.embed(chunk_texts, batch_size=sparse_batch_size)
    )
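
    # Each point carries the chunk dict as its payload plus named "dense" and "sparse" vectors.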
    # Upload to db
    client.upload_points(
        collection_name=collection_name,
        points=[
            qmodels.PointStruct(
                id=idx,
                payload=chunk,
                vector={"dense": dense, "sparse": sparse.as_object()},
            )
            for idx, (chunk, dense, sparse) in enumerate(
                zip(chunks, dense_vectors, sparse_vectors)
            )
        ],
    )
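

# Entry point: download the knowledge-card parquet from the Hugging Face Hub, then ingest it.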
def ingest():
    downloaded_path = hf_hub_download(
        repo_id=DATA_REPO, filename=DATA_FILENAME, token=HF_TOKEN, repo_type="dataset"
    )
    print(f"Downloaded knowledge card dataset; path = {downloaded_path}", flush=True)
    chunk_df = pd.read_parquet(downloaded_path)
    chunks = chunk_df.to_dict(orient="records")
    ingest_data(chunks=chunks)
    print("Ingestion is finished.", flush=True)


if __name__ == "__main__":
    ingest()