# NekoStickers / app/database.py
# Last commit: "add batch upload" (66e44c7) by zhangfeng144
from pymilvus import MilvusClient
from sentence_transformers import SentenceTransformer
from typing import List, Dict, Any, Optional, Union
import logging
from app.config import MILVUS_DB_URL, MILVUS_DB_TOKEN, EMBEDDING_MODEL, DATASET_ID
# Configure logging (INFO level, timestamped records) and create this module's logger.
logging.basicConfig(
    format='%(asctime)s %(levelname)s %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
class Database:
    """Database access layer that handles all interaction with Milvus.

    Stickers are stored in a single collection (``self.collection_name``) as
    rows holding an embedding of the description plus the ``title``,
    ``description``, ``tags``, ``file_name`` and ``image_hash`` fields.
    """

    def __init__(self):
        """Create the Milvus client and load the sentence-embedding model."""
        self.client = MilvusClient(
            uri=MILVUS_DB_URL,
            token=MILVUS_DB_TOKEN,
        )
        self.model = SentenceTransformer(EMBEDDING_MODEL, trust_remote_code=True)
        print('初始化模型完成', self.model)
        self.collection_name = "stickers"

    def init_collection(self) -> bool:
        """Create the sticker collection if it does not exist yet.

        Returns:
            True when the collection exists after the call (either it was
            already there or it was created), False if any Milvus call raised.
        """
        try:
            existing = self.client.list_collections()
            print('初始化 Milvus 数据库', existing)
            # Check for *this* collection specifically: testing whether any
            # collection exists would skip creation as soon as an unrelated
            # collection is present on the server.
            if self.collection_name not in existing:
                self.client.create_collection(
                    collection_name=self.collection_name,
                    dimension=768,
                    primary_field="id",
                    auto_id=True,
                )
                # NOTE(review): this create_index call is kept exactly as
                # written. It passes an empty index_params together with loose
                # index_type/metric_type/params keywords - confirm against the
                # installed pymilvus version that it really builds the intended
                # IVF_SQ8/COSINE index on the vector field.
                self.client.create_index(
                    collection_name=self.collection_name,
                    index_type="IVF_SQ8",
                    metric_type="COSINE",
                    params={"nlist": 128},
                    index_params={},
                )
            logger.info(f"Collection initialized: {self.collection_name}")
            print('初始化 Milvus 数据库成功', self.client.list_collections())
            return True
        except Exception as e:
            logger.error(f"Collection initialization failed: {str(e)}")
            return False

    def encode_text(self, text: str) -> List[float]:
        """Encode a single text into its embedding vector."""
        return self.model.encode(text).tolist()

    def _encode_texts(self, texts: List[str]) -> List[List[float]]:
        """Encode several texts with one model call; returns one vector per text."""
        return self.model.encode(texts).tolist()

    @staticmethod
    def _normalize_tags(tags: Union[str, List[str], None]) -> List[str]:
        """Turn a comma-separated tag string into a clean list of tags.

        Whitespace around each tag is stripped and empty tags are dropped, so
        ``""`` yields ``[]`` and ``"a, b"`` yields ``["a", "b"]``. ``None``
        yields ``[]``. Any other value (normally a list) is returned unchanged.
        """
        if tags is None:
            return []
        if isinstance(tags, str):
            return [part.strip() for part in tags.split(",") if part.strip()]
        return tags

    def _build_row(
        self,
        vector: List[float],
        title: str,
        description: str,
        tags: Union[str, List[str], None],
        file_path: str,
        image_hash: Optional[str],
    ) -> Dict[str, Any]:
        """Build the row dict inserted into Milvus for one sticker."""
        return {
            "vector": vector,
            "title": title,
            "description": description,
            "tags": self._normalize_tags(tags),
            # The caller-facing name is file_path; the stored field is file_name.
            "file_name": file_path,
            "image_hash": image_hash,
        }

    def store_sticker(
        self,
        title: str,
        description: str,
        tags: Union[str, List[str]],
        file_path: str,
        image_hash: Optional[str] = None,
    ) -> bool:
        """Embed one sticker's description and insert the sticker into Milvus.

        Args:
            title: Sticker title.
            description: Text that is embedded and used for similarity search.
            tags: Either a list of tags or a comma-separated string.
            file_path: Stored in the ``file_name`` field.
            image_hash: Optional hash of the image, stored as-is.

        Returns:
            True on success, False if encoding or insertion raised.
        """
        try:
            row = self._build_row(
                vector=self.encode_text(description),
                title=title,
                description=description,
                tags=tags,
                file_path=file_path,
                image_hash=image_hash,
            )
            logger.info(f"Storing to Milvus - title: {title}, description: {description}, file_path: {file_path}, tags: {row['tags']}, image_hash: {image_hash}")
            self.client.insert(
                collection_name=self.collection_name,
                data=[row],
            )
            logger.info("Storing to Milvus Success ✅")
            return True
        except Exception as e:
            logger.error(f"Failed to store sticker: {str(e)}")
            return False

    def search_stickers(self, description: str, limit: int = 2) -> List[Dict[str, Any]]:
        """Return the stickers whose descriptions are closest to ``description``.

        Args:
            description: Query text; an empty value returns ``[]`` immediately.
            limit: Maximum number of hits requested from Milvus.

        Returns:
            The hits for the single query vector, or ``[]`` when there are
            none or the search raised.
        """
        if not description:
            return []
        try:
            text_vector = self.encode_text(description)
            logger.info(f"Searching Milvus - query: {description}, limit: {limit}")
            results = self.client.search(
                collection_name=self.collection_name,
                data=[text_vector],
                limit=limit,
                search_params={
                    "metric_type": "COSINE",
                },
                output_fields=["title", "description", "tags", "file_name"],
            )
            logger.info(f"Search Result: {results}")
            # One query vector was sent, so only the first result set matters;
            # guard against an empty outer list instead of raising IndexError.
            return results[0] if results else []
        except Exception as e:
            logger.error(f"Search failed: {str(e)}")
            return []

    def get_all_stickers(self, limit: int = 1000) -> List[Dict[str, Any]]:
        """Return up to ``limit`` stickers, or ``[]`` if the query raised."""
        try:
            results = self.client.query(
                collection_name=self.collection_name,
                filter="",
                limit=limit,
                output_fields=["title", "description", "tags", "file_name", "image_hash"],
            )
            logger.info(f"Query All Stickers - limit: {limit}, results count: {len(results)}")
            return results
        except Exception as e:
            logger.error(f"Failed to get all stickers: {str(e)}")
            return []

    def check_image_exists(self, image_hash: str) -> bool:
        """Check whether a sticker with the given image hash is already stored.

        Returns:
            True if at least one row has this ``image_hash``. False if there is
            none, if ``image_hash`` is empty/None, or if the query raised.
        """
        if not image_hash:
            # Stickers may be stored without a hash; there is nothing to match.
            return False
        try:
            # Escape backslashes and single quotes so the value cannot end the
            # quoted string literal of the filter expression early.
            escaped = str(image_hash).replace("\\", "\\\\").replace("'", "\\'")
            results = self.client.query(
                collection_name=self.collection_name,
                filter=f"image_hash == '{escaped}'",
                limit=1,
                output_fields=["file_name", "image_hash"],
            )
            return len(results) > 0
        except Exception as e:
            logger.error(f"Failed to check file exists: {str(e)}")
            return False

    def delete_sticker(self, sticker_id: int) -> str:
        """Delete one sticker by primary key and return a status message.

        Returns:
            A success message, or a failure message containing the error text
            if the delete call raised.
        """
        try:
            logger.info(f"Deleting sticker - id: {sticker_id}")
            res = self.client.delete(
                collection_name=self.collection_name,
                ids=[sticker_id],
            )
            logger.info(f"Deleted sticker - id: {sticker_id}")
            print(res)
            return f"Sticker with ID {sticker_id} deleted successfully"
        except Exception as e:
            logger.error(f"Failed to delete sticker: {str(e)}")
            return f"Failed to delete sticker: {str(e)}"

    def batch_store_stickers(self, stickers: List[Dict[str, Any]], batch_size: int = 100) -> bool:
        """Store many stickers in Milvus, ``batch_size`` rows per insert call.

        Args:
            stickers: One dict per sticker with the keys
                - title: str
                - description: str
                - tags: list of tags or comma-separated string
                - file_path: str
                - image_hash: str (optional)
                Missing keys fall back to ``""`` (``[]`` for tags, ``None``
                for image_hash).
            batch_size: Number of stickers per insert; must be >= 1.
                Defaults to 100.

        Returns:
            True if every batch was inserted (or ``stickers`` is empty).
            False if ``batch_size`` is invalid or any step raised; batches
            inserted before the failure are not rolled back.
        """
        try:
            total_stickers = len(stickers)
            if total_stickers == 0:
                logger.warning("No stickers to store")
                return True
            if batch_size < 1:
                # A negative step would make the range below empty and report
                # success without storing anything.
                logger.error(f"Failed to batch store stickers: batch_size must be >= 1, got {batch_size}")
                return False
            logger.info(f"Starting batch store of {total_stickers} stickers")
            for start in range(0, total_stickers, batch_size):
                batch = stickers[start:start + batch_size]
                descriptions = [sticker.get("description", "") for sticker in batch]
                # Encode the whole batch in one model call instead of one call
                # per sticker.
                vectors = self._encode_texts(descriptions)
                batch_data = [
                    self._build_row(
                        vector=vector,
                        title=sticker.get("title", ""),
                        description=text,
                        tags=sticker.get("tags", []),
                        file_path=sticker.get("file_path", ""),
                        image_hash=sticker.get("image_hash"),
                    )
                    for sticker, text, vector in zip(batch, descriptions, vectors)
                ]
                self.client.insert(
                    collection_name=self.collection_name,
                    data=batch_data,
                )
                logger.info(f"Batch {start // batch_size + 1} stored successfully - {len(batch_data)} stickers")
            logger.info("All stickers stored successfully ✅")
            return True
        except Exception as e:
            logger.error(f"Failed to batch store stickers: {str(e)}")
            return False
# Module-level Database instance (presumably imported by other modules - verify against callers).
# NOTE: constructing Database() creates the Milvus client and loads the embedding model
# at import time; it does not call init_collection().
db = Database()