import logging
from typing import List, Dict, Any, Optional, Union

from pymilvus import MilvusClient
from sentence_transformers import SentenceTransformer

from app.config import MILVUS_DB_URL, MILVUS_DB_TOKEN, EMBEDDING_MODEL, DATASET_ID

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s %(message)s')
logger = logging.getLogger(__name__)


def _escape_filter_literal(value: str) -> str:
    """Escape *value* so it can sit inside a single-quoted filter string literal.

    Backslashes are escaped first so the backslashes added for quotes are not
    doubled again. Values without quotes or backslashes are returned unchanged.
    """
    return value.replace("\\", "\\\\").replace("'", "\\'")


class Database:
    """Database access class that handles all interaction with Milvus."""

    # Vector dimension passed to ``create_collection``.
    # NOTE(review): presumably this must equal the output size of
    # EMBEDDING_MODEL -- confirm whenever the configured model changes.
    EMBEDDING_DIM = 768

    def __init__(self):
        """Connect to Milvus and load the sentence-embedding model."""
        self.client = MilvusClient(uri=MILVUS_DB_URL, token=MILVUS_DB_TOKEN)
        self.model = SentenceTransformer(EMBEDDING_MODEL, trust_remote_code=True)
        logger.info('初始化模型完成 %s', self.model)
        self.collection_name = "stickers"

    def init_collection(self) -> bool:
        """Create the sticker collection if it does not exist yet.

        Returns:
            True when the collection already exists or was created without an
            exception, False when any step raised.
        """
        try:
            existing = self.client.list_collections()
            logger.info('初始化 Milvus 数据库 %s', existing)
            # Check for *this* collection by name: testing whether the server
            # has any collection at all would skip creation as soon as an
            # unrelated collection exists.
            if self.collection_name not in existing:
                self.client.create_collection(
                    collection_name=self.collection_name,
                    dimension=self.EMBEDDING_DIM,
                    primary_field="id",
                    auto_id=True,
                )
                # NOTE(review): index settings are passed as loose keyword
                # arguments together with an empty ``index_params`` -- confirm
                # against the pymilvus version in use that this actually
                # builds the intended IVF_SQ8 / COSINE index.
                self.client.create_index(
                    collection_name=self.collection_name,
                    index_type="IVF_SQ8",
                    metric_type="COSINE",
                    params={"nlist": 128},
                    index_params={},
                )
                logger.info(f"Collection initialized: {self.collection_name}")
            logger.info('初始化 Milvus 数据库成功 %s', self.client.list_collections())
            return True
        except Exception as e:
            logger.error(f"Collection initialization failed: {str(e)}")
            return False

    def encode_text(self, text: str) -> List[float]:
        """Encode *text* with the embedding model and return it as a list."""
        return self.model.encode(text).tolist()

    def store_sticker(
        self,
        title: str,
        description: str,
        tags: Union[str, List[str]],
        file_path: str,
        image_hash: Optional[str] = None,
    ) -> bool:
        """Embed the description and insert one sticker record into Milvus.

        Args:
            title: Sticker title.
            description: Text that is embedded and stored as the vector.
            tags: Either a list of tags or a comma-separated string; a string
                is split on commas, each tag is stripped of surrounding
                whitespace and empty entries are dropped.
            file_path: Stored under the ``file_name`` key.
            image_hash: Optional hash stored under the ``image_hash`` key.

        Returns:
            True when the insert call completed, False when anything raised.
        """
        try:
            vector = self.encode_text(description)

            # Normalise the tag format
            if isinstance(tags, str):
                tags = [tag.strip() for tag in tags.split(",") if tag.strip()]

            logger.info(
                f"Storing to Milvus - title: {title}, description: {description}, "
                f"file_path: {file_path}, tags: {tags}, image_hash: {image_hash}"
            )

            self.client.insert(
                collection_name=self.collection_name,
                data=[{
                    "vector": vector,
                    "title": title,
                    "description": description,
                    "tags": tags,
                    "file_name": file_path,
                    "image_hash": image_hash,
                }],
            )
            logger.info("Storing to Milvus Success ✅")
            return True
        except Exception as e:
            logger.error(f"Failed to store sticker: {str(e)}")
            return False

    def search_stickers(self, description: str, limit: int = 2) -> List[Dict[str, Any]]:
        """Search stickers by similarity to *description*.

        Args:
            description: Query text; an empty value short-circuits to ``[]``.
            limit: Maximum number of hits requested from Milvus.

        Returns:
            The hits for the single query vector, or ``[]`` when the query is
            empty, nothing is returned, or the search raised.
        """
        if not description:
            return []

        try:
            text_vector = self.encode_text(description)
            logger.info(f"Searching Milvus - query: {description}, limit: {limit}")

            results = self.client.search(
                collection_name=self.collection_name,
                data=[text_vector],
                limit=limit,
                search_params={
                    "metric_type": "COSINE",
                },
                output_fields=["title", "description", "tags", "file_name"],
            )

            logger.info(f"Search Result: {results}")
            # Only one query vector is sent, so only the first result set is
            # relevant; guard explicitly instead of relying on IndexError.
            return results[0] if results else []
        except Exception as e:
            logger.error(f"Search failed: {str(e)}")
            return []

    def get_all_stickers(self, limit: int = 1000) -> List[Dict[str, Any]]:
        """Query up to *limit* stickers without any filter.

        Returns:
            The rows returned by Milvus, or ``[]`` when the query raised.
        """
        try:
            results = self.client.query(
                collection_name=self.collection_name,
                filter="",
                limit=limit,
                output_fields=["title", "description", "tags", "file_name", "image_hash"],
            )

            logger.info(f"Query All Stickers - limit: {limit}, results count: {len(results)}")
            return results
        except Exception as e:
            logger.error(f"Failed to get all stickers: {str(e)}")
            return []

    def check_image_exists(self, image_hash: str) -> bool:
        """Check whether a record with the given image hash is already stored.

        Args:
            image_hash: Hash to look up; an empty or ``None`` value is treated
                as "not stored" without querying.

        Returns:
            True when at least one matching row is found; False when none is
            found or the query raised.
        """
        if not image_hash:
            return False

        try:
            # Escape the value so quotes or backslashes inside it cannot
            # terminate the string literal and change the filter expression.
            safe_hash = _escape_filter_literal(str(image_hash))
            results = self.client.query(
                collection_name=self.collection_name,
                filter=f"image_hash == '{safe_hash}'",
                limit=1,
                output_fields=["file_name", "image_hash"],
            )

            exists = len(results) > 0
            logger.info(f"Check file exists - hash: {image_hash}, exists: {exists}, results: {results}")
            return exists
        except Exception as e:
            logger.error(f"Failed to check file exists: {str(e)}")
            return False

    def delete_sticker(self, sticker_id: int) -> str:
        """Delete the sticker with primary key *sticker_id*.

        Returns:
            A human-readable status message: a success message when the delete
            call completed, otherwise a message containing the error text.
        """
        try:
            logger.info(f"Deleting sticker - id: {sticker_id}")
            res = self.client.delete(
                collection_name=self.collection_name,
                ids=[sticker_id],
            )
            logger.info(f"Deleted sticker - id: {sticker_id}")
            logger.info("%s", res)
            return f"Sticker with ID {sticker_id} deleted successfully"
        except Exception as e:
            logger.error(f"Failed to delete sticker: {str(e)}")
            return f"Failed to delete sticker: {str(e)}"


# Create the shared database instance.
# NOTE(review): init_collection() is not called in this file -- presumably a
# caller invokes it; confirm.
db = Database()