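"""Ingest knowledge-card chunks into a local Qdrant collection.

Downloads a parquet file of chunks from the Hugging Face Hub, embeds each chunk
with a dense encoder (mxbai-embed-large-v1) and a sparse BM25 encoder, and
uploads both vectors into a hybrid "knowledge_cards" collection.
"""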
import os
from pathlib import Path
from typing import Any

import pandas as pd
from fastembed import SparseTextEmbedding, SparseEmbedding
from sentence_transformers import SentenceTransformer
from huggingface_hub import hf_hub_download
from qdrant_client import QdrantClient
from qdrant_client import models as qmodels

VLLM_DTYPE = os.getenv("VLLM_DTYPE")

DATA_PATH = Path(os.getenv("DATA_PATH"))
DB_PATH = DATA_PATH / "db"
HF_TOKEN = os.getenv("HF_TOKEN")

RECREATE_DB = os.getenv("RECREATE_DB", "False").lower() == "true"
DATA_REPO = os.getenv("DATA_REPO")
DATA_FILENAME = os.getenv("DATA_FILENAME")

client = QdrantClient(path=str(DB_PATH))
collection_name = "knowledge_cards"
dense_model_dims = 1024
dense_batch_size = 128
sparse_batch_size = 256

dense_encoder = SentenceTransformer(
    model_name_or_path="mixedbread-ai/mxbai-embed-large-v1",
    device="cuda",
    model_kwargs={"torch_dtype": VLLM_DTYPE},
)
sparse_encoder = SparseTextEmbedding(model_name="Qdrant/bm25", cuda=True)


# Utils
def convert_serialized_sparse_embeddings(sparse_dict: dict[str, float]):
    """Rebuild a SparseEmbedding from a serialized dict whose keys were
    stringified (e.g. for PyArrow/parquet compatibility)."""
    return SparseEmbedding.from_dict({int(k): v for k, v in sparse_dict.items()})


def ingest_data(chunks: list[dict[str, Any]]):
    if client.collection_exists(collection_name) and RECREATE_DB:
        print("Recreating collection.", flush=True)
        client.delete_collection(collection_name)
    elif client.collection_exists(collection_name):
        print("Collection already exists, skipping ingestion.", flush=True)
        return

    print("Ingesting knowledge cards...", flush=True)
    client.create_collection(
        collection_name=collection_name,
        vectors_config={
            "dense": qmodels.VectorParams(
                size=dense_model_dims,
                distance=qmodels.Distance.COSINE,
            )
        },
        sparse_vectors_config={
            "sparse": qmodels.SparseVectorParams(modifier=qmodels.Modifier.IDF)
        },
    )

    # Generate embeddings
    chunk_texts = [chunk["text"] for chunk in chunks]
    dense_vectors = list(
        dense_encoder.encode(
            chunk_texts,
            batch_size=dense_batch_size,
            normalize_embeddings=True,
        )
    )
    sparse_vectors = list(
        sparse_encoder.embed(chunk_texts, batch_size=sparse_batch_size)
    )

    # Upload to db
    client.upload_points(
        collection_name=collection_name,
        points=[
            qmodels.PointStruct(
                id=idx,
                payload=chunk,
                vector={"dense": dense, "sparse": sparse.as_object()},
            )
            for idx, (chunk, dense, sparse) in enumerate(
                zip(chunks, dense_vectors, sparse_vectors)
            )
        ],
    )


def ingest():
    downloaded_path = hf_hub_download(
        repo_id=DATA_REPO, filename=DATA_FILENAME, token=HF_TOKEN, repo_type="dataset"
    )
    print(f"Downloaded knowledge card dataset; path = {downloaded_path}", flush=True)
    chunk_df = pd.read_parquet(downloaded_path)
    chunks = chunk_df.to_dict(orient="records")
    ingest_data(chunks=chunks)
    print("Ingestion is finished.", flush=True)


if __name__ == "__main__":
    ingest()