File size: 4,872 Bytes
66340f1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
import os
from qdrant_client import QdrantClient
from qdrant_client.http import models

# from dotenv import load_dotenv
from uuid import uuid4


# load_dotenv()

COLLECTION_NAME = os.getenv("COLLECTION_NAME")
COLLECTION_SIZE = os.getenv("COLLECTION_SIZE")
QDRANT_PORT = os.getenv("QDRANT_PORT")
QDRANT_HOST = os.getenv("QDRANT_HOST")
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")


class QdrantManager:
    """
    A class for managing collectionsget_collection_info in the Qdrant database.

    Args:
        collection_name (str): The name of the collection to manage.
        collection_size (int): The maximum number of documents in the collection.
        port (int): The port number for the Qdrant API.
        host (str): The hostname or IP address for the Qdrant server.
        api_key (str): The API key for authenticating with the Qdrant server.
        recreate_collection (bool): Whether to recreate the collection if it already exists.

    Attributes:
        client (qdrant_client.QdrantClient): The Qdrant client object for interacting with the API.
    """

    def __init__(
        self,
        collection_name=COLLECTION_NAME,
        collection_size: int = COLLECTION_SIZE,
        port: int = QDRANT_PORT,
        host=QDRANT_HOST,
        api_key=QDRANT_API_KEY,
        recreate_collection: bool = False,
    ):
        self.collection_name = collection_name
        self.collection_size = collection_size
        self.host = host
        self.port = port
        self.api_key = api_key

        self.client = QdrantClient(host=host, port=port, api_key=api_key)
        self.setup_collection(collection_size, recreate_collection)

    def setup_collection(self, collection_size: int, recreate_collection: bool):
        if recreate_collection:
            self.recreate_collection()

        try:
            collection_info = self.get_collection_info()
            current_collection_size = collection_info["vector_size"]

            if current_collection_size != int(collection_size):
                raise ValueError(
                    f"""
                    Existing collection {self.collection_name} has different collection size
                    To use the new collection configuration, you need to recreate the collection as it already exists with a different configuration.
                    use recreate_collection = True.
                    """
                )

        except Exception as e:
            self.recreate_collection()
            print(e)

    def recreate_collection(self):
        self.client.recreate_collection(
            collection_name=self.collection_name,
            vectors_config=models.VectorParams(
                size=self.collection_size, distance=models.Distance.COSINE
            ),
        )

    def get_collection_info(self):
        collection_info = self.client.get_collection(
            collection_name=self.collection_name
        )

        return {
            "points_count": int(collection_info.points_count),
            "vectors_count": int(collection_info.vectors_count),
            "indexed_vectors_count": int(collection_info.indexed_vectors_count),
            "vector_size": int(collection_info.config.params.vectors.size),
        }

    def upsert_point(self, id, payload, embedding):
        response = self.client.upsert(
            collection_name=self.collection_name,
            points=[
                models.PointStruct(
                    id=id,
                    payload=payload,
                    vector=embedding,
                ),
            ],
        )

        return response

    def upsert_points(self, ids, payloads, embeddings):
        response = self.client.upsert(
            collection_name=self.collection_name,
            points=models.Batch(
                ids=ids,
                payloads=payloads,
                vectors=embeddings,
            ),
        )

        return response

    def search_point(self, query_vector, user_id, document_id, limit):
        response = self.client.search(
            collection_name=self.collection_name,
            query_filter=models.Filter(
                must=[
                    models.FieldCondition(
                        key="user_id",
                        match=models.MatchValue(
                            value=user_id,
                        ),
                    ),
                    models.FieldCondition(
                        key="document_id",
                        match=models.MatchValue(value=document_id),
                    ),
                ]
            ),
            query_vector=query_vector,
            limit=limit,
        )

        return response

    def delete_collection(self):
        response = self.client.delete_collection(collection_name=self.collection_name)

        return response


qdrant_manager = QdrantManager()