Spaces:
Runtime error
Runtime error
File size: 4,872 Bytes
66340f1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 |
import os
from qdrant_client import QdrantClient
from qdrant_client.http import models
# from dotenv import load_dotenv
from uuid import uuid4
# load_dotenv()
# Connection and collection settings, read once when the module is imported.
# os.environ.get yields the raw string when the variable is set and None
# otherwise, so COLLECTION_SIZE and QDRANT_PORT are strings here, not ints.
COLLECTION_NAME = os.environ.get("COLLECTION_NAME")
COLLECTION_SIZE = os.environ.get("COLLECTION_SIZE")
QDRANT_PORT = os.environ.get("QDRANT_PORT")
QDRANT_HOST = os.environ.get("QDRANT_HOST")
QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY")
class QdrantManager:
    """
    A class for managing a single collection in the Qdrant database.

    On construction it connects to Qdrant and makes sure the collection
    exists with the configured vector size (see ``setup_collection``).

    Args:
        collection_name (str): The name of the collection to manage.
        collection_size (int): The dimensionality of the vectors stored in
            the collection (passed to Qdrant as ``VectorParams.size``).
            String values, such as the raw ``COLLECTION_SIZE`` environment
            value, are converted with ``int()``.
        port (int): The port number for the Qdrant API. String values are
            converted with ``int()``; ``None`` is passed through unchanged.
        host (str): The hostname or IP address for the Qdrant server.
        api_key (str): The API key for authenticating with the Qdrant server.
        recreate_collection (bool): Whether to drop and recreate the
            collection on startup. All stored points are lost.

    Attributes:
        client (qdrant_client.QdrantClient): The Qdrant client object for
            interacting with the API.

    Raises:
        ValueError: If ``collection_size`` is missing or not an integer, or
            if an existing collection has a different vector size and
            ``recreate_collection`` is False.
    """

    def __init__(
        self,
        collection_name=COLLECTION_NAME,
        collection_size: int = COLLECTION_SIZE,
        port: int = QDRANT_PORT,
        host=QDRANT_HOST,
        api_key=QDRANT_API_KEY,
        recreate_collection: bool = False,
    ):
        if collection_size is None:
            raise ValueError(
                "collection_size is required (set the COLLECTION_SIZE "
                "environment variable or pass it explicitly)."
            )
        # Environment variables are always strings; Qdrant expects ints.
        collection_size = int(collection_size)
        if port is not None:
            port = int(port)

        self.collection_name = collection_name
        self.collection_size = collection_size
        self.host = host
        self.port = port
        self.api_key = api_key
        self.client = QdrantClient(host=host, port=port, api_key=api_key)
        self.setup_collection(collection_size, recreate_collection)

    def setup_collection(self, collection_size: int, recreate_collection: bool):
        """
        Ensure the managed collection exists with the expected vector size.

        A missing collection is created. An existing collection is left
        untouched unless ``recreate_collection`` is True, so stored points
        are never deleted implicitly.

        Args:
            collection_size (int): Expected vector size of the collection.
            recreate_collection (bool): If True, drop and recreate the
                collection unconditionally. All stored points are lost.

        Raises:
            ValueError: If the collection already exists with a different
                vector size and ``recreate_collection`` is False.
        """
        if recreate_collection:
            self.recreate_collection()
            return

        # Check existence explicitly instead of treating *any* error from
        # get_collection as "missing". The old code did that, which wiped the
        # collection on transient errors and even on its own mismatch error.
        existing_names = {
            description.name
            for description in self.client.get_collections().collections
        }
        if self.collection_name not in existing_names:
            # Nothing stored yet, so creating the collection loses no data.
            self.recreate_collection()
            return

        expected_size = int(collection_size)
        current_collection_size = self.get_collection_info()["vector_size"]
        if current_collection_size != expected_size:
            raise ValueError(
                f"Existing collection {self.collection_name} has different "
                f"collection size ({current_collection_size}, requested "
                f"{expected_size}). To use the new collection configuration, "
                "you need to recreate the collection as it already exists "
                "with a different configuration: use recreate_collection = True."
            )

    def recreate_collection(self):
        """
        (Re)create the collection with cosine distance and ``self.collection_size``.

        Any existing collection with the same name is dropped, together with
        all of its points.
        """
        # NOTE(review): QdrantClient.recreate_collection is deprecated in newer
        # qdrant-client releases (delete_collection + create_collection is the
        # replacement); kept here to stay compatible with the pinned version.
        self.client.recreate_collection(
            collection_name=self.collection_name,
            vectors_config=models.VectorParams(
                size=self.collection_size, distance=models.Distance.COSINE
            ),
        )

    def get_collection_info(self):
        """
        Return basic statistics and the vector size of the managed collection.

        Returns:
            dict: ``points_count``, ``vectors_count`` and
            ``indexed_vectors_count`` (int, or None when the client does not
            report the value), plus ``vector_size`` (int).
        """
        collection_info = self.client.get_collection(
            collection_name=self.collection_name
        )

        def _as_int(value):
            # The client model declares these counts optional; int(None) raises.
            return None if value is None else int(value)

        return {
            "points_count": _as_int(collection_info.points_count),
            # vectors_count may be None or absent depending on the
            # qdrant-client version, hence getattr with a default.
            "vectors_count": _as_int(getattr(collection_info, "vectors_count", None)),
            "indexed_vectors_count": _as_int(collection_info.indexed_vectors_count),
            "vector_size": int(collection_info.config.params.vectors.size),
        }

    def upsert_point(self, id, payload, embedding):
        """
        Insert or overwrite one point.

        Args:
            id: Point id (the parameter name shadows the builtin ``id`` and is
                kept for compatibility with existing callers).
            payload (dict): Payload stored alongside the vector.
            embedding: The vector; its length should equal ``collection_size``.

        Returns:
            The response object returned by ``QdrantClient.upsert``.
        """
        return self.client.upsert(
            collection_name=self.collection_name,
            points=[
                models.PointStruct(
                    id=id,
                    payload=payload,
                    vector=embedding,
                ),
            ],
        )

    def upsert_points(self, ids, payloads, embeddings):
        """
        Insert or overwrite several points in a single batch request.

        ``ids``, ``payloads`` and ``embeddings`` are parallel sequences; item
        *i* of each describes the same point.

        Returns:
            The response object returned by ``QdrantClient.upsert``.
        """
        return self.client.upsert(
            collection_name=self.collection_name,
            points=models.Batch(
                ids=ids,
                payloads=payloads,
                vectors=embeddings,
            ),
        )

    def search_point(self, query_vector, user_id, document_id, limit):
        """
        Search for the nearest vectors belonging to one user and document.

        Only points whose payload ``user_id`` equals ``user_id`` AND whose
        payload ``document_id`` equals ``document_id`` are considered.

        Args:
            query_vector: The query embedding.
            user_id: Exact value required in the ``user_id`` payload field.
            document_id: Exact value required in the ``document_id`` payload field.
            limit (int): Maximum number of hits to return.

        Returns:
            Whatever ``QdrantClient.search`` returns for this query.
        """
        return self.client.search(
            collection_name=self.collection_name,
            query_filter=models.Filter(
                must=[
                    models.FieldCondition(
                        key="user_id",
                        match=models.MatchValue(
                            value=user_id,
                        ),
                    ),
                    models.FieldCondition(
                        key="document_id",
                        match=models.MatchValue(value=document_id),
                    ),
                ]
            ),
            query_vector=query_vector,
            limit=limit,
        )

    def delete_collection(self):
        """
        Delete the managed collection and all of its points.

        Returns:
            The value returned by ``QdrantClient.delete_collection``.
        """
        return self.client.delete_collection(collection_name=self.collection_name)
# Module-level shared instance. Constructing it at import time builds the
# QdrantClient from the environment-derived defaults and runs
# setup_collection, so importing this module contacts the Qdrant server and
# may create the collection.
qdrant_manager = QdrantManager()
|