Spaces:
Running
Running
from pymilvus import MilvusClient | |
from sentence_transformers import SentenceTransformer | |
from typing import List, Dict, Any, Optional, Union | |
import logging | |
from app.config import MILVUS_DB_URL, MILVUS_DB_TOKEN, EMBEDDING_MODEL, DATASET_ID | |
# Configure root logging once at import time: INFO level, "time LEVEL message".
_LOG_FORMAT = '%(asctime)s %(levelname)s %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
# Module logger used by everything below.
logger = logging.getLogger(__name__)
class Database:
    """Data-access layer for sticker records stored in Milvus.

    Every record lives in the ``stickers`` collection and holds an embedding of
    the sticker description (``vector``) plus ``title``, ``description``,
    ``tags``, ``file_name`` and ``image_hash`` fields.
    """

    def __init__(self):
        # Milvus connection settings come from app.config.
        self.client = MilvusClient(uri=MILVUS_DB_URL, token=MILVUS_DB_TOKEN)
        # The same model embeds stored descriptions and search queries
        # (see encode_text), so both end up in one vector space.
        self.model = SentenceTransformer(EMBEDDING_MODEL, trust_remote_code=True)
        logger.info("初始化模型完成 %s", self.model)
        self.collection_name = "stickers"

    def init_collection(self, dimension: int = 768) -> bool:
        """Create the sticker collection in Milvus if it does not exist yet.

        Uses MilvusClient's quick-setup mode (``create_collection`` with
        ``dimension``): an auto-generated ``id`` primary key, a ``vector``
        field of ``dimension`` floats, and the remaining sticker fields
        stored as dynamic fields.

        Args:
            dimension: Vector length. It must match the embedding model's
                output size. Defaults to 768.

        Returns:
            True if the collection exists or was created, False on any error
            (the error is logged).
        """
        try:
            existing = self.client.list_collections()
            logger.info("初始化 Milvus 数据库 %s", existing)
            # Check for *this* collection by name. The previous "are there any
            # collections at all" check skipped creation whenever some
            # unrelated collection already existed.
            if self.collection_name not in existing:
                # metric_type is passed explicitly rather than relying on the
                # library default, so it matches the COSINE metric that
                # search_stickers queries with.
                #
                # NOTE(review): the previous create_index(index_type="IVF_SQ8",
                # ..., index_params={}) call never built that index.
                # MilvusClient.create_index expects an IndexParams object (from
                # prepare_index_params()), so an empty dict is either ignored
                # or rejected, depending on the pymilvus version. Quick setup
                # already indexes "vector" with the metric given here. If
                # IVF_SQ8 specifically is required, create the collection from
                # an explicit schema plus prepare_index_params().
                self.client.create_collection(
                    collection_name=self.collection_name,
                    dimension=dimension,
                    primary_field_name="id",
                    metric_type="COSINE",
                    auto_id=True,
                )
                logger.info(f"Collection initialized: {self.collection_name}")
            else:
                logger.info(f"Collection already exists: {self.collection_name}")
            logger.info("初始化 Milvus 数据库成功 %s", self.client.list_collections())
            return True
        except Exception as e:
            logger.error(f"Collection initialization failed: {str(e)}")
            return False

    def encode_text(self, text: str) -> List[float]:
        """Embed a single text and return the vector as a plain list of floats."""
        return self.model.encode(text).tolist()

    def encode_texts(self, texts: List[str]) -> List[List[float]]:
        """Embed several texts in one model call.

        Returns one vector per input text, in input order. Returns ``[]``
        for an empty input without calling the model.
        """
        texts = list(texts)
        if not texts:
            return []
        return self.model.encode(texts).tolist()

    @staticmethod
    def _normalize_tags(tags: Union[str, List[str], None]) -> List[str]:
        """Return tags as a list.

        A comma-separated string is split, each piece is stripped of
        surrounding whitespace, and empty pieces are dropped. ``None`` becomes
        ``[]``. Any other iterable is copied into a list unchanged.
        """
        if tags is None:
            return []
        if isinstance(tags, str):
            return [tag.strip() for tag in tags.split(",") if tag.strip()]
        return list(tags)

    @staticmethod
    def _escape_filter_literal(value: str) -> str:
        """Escape ``value`` for use inside a single-quoted Milvus filter string.

        Backslashes are escaped first so the quote escapes added next are not
        doubled.
        """
        return value.replace("\\", "\\\\").replace("'", "\\'")

    @staticmethod
    def _build_record(
        vector: List[float],
        title: str,
        description: str,
        tags: List[str],
        file_path: str,
        image_hash: Optional[str],
    ) -> Dict[str, Any]:
        """Assemble one Milvus row. The caller's ``file_path`` is stored as ``file_name``."""
        return {
            "vector": vector,
            "title": title,
            "description": description,
            "tags": tags,
            "file_name": file_path,
            "image_hash": image_hash,
        }

    def store_sticker(self, title: str, description: str, tags: Union[str, List[str]], file_path: str, image_hash: Optional[str] = None) -> bool:
        """Embed ``description`` and insert one sticker record into Milvus.

        Args:
            title: Sticker title.
            description: Text that is embedded into the record's vector.
            tags: Tag list, or a comma-separated tag string.
            file_path: Stored in the record's ``file_name`` field.
            image_hash: Optional content hash, used by check_image_exists.

        Returns:
            True on success, False on any error (the error is logged).
        """
        try:
            vector = self.encode_text(description)
            tags = self._normalize_tags(tags)
            logger.info(f"Storing to Milvus - title: {title}, description: {description}, file_path: {file_path}, tags: {tags}, image_hash: {image_hash}")
            self.client.insert(
                collection_name=self.collection_name,
                data=[self._build_record(vector, title, description, tags, file_path, image_hash)],
            )
            logger.info("Storing to Milvus Success ✅")
            return True
        except Exception as e:
            logger.error(f"Failed to store sticker: {str(e)}")
            return False

    def search_stickers(self, description: str, limit: int = 2) -> List[Dict[str, Any]]:
        """Return up to ``limit`` stickers most similar (COSINE) to ``description``.

        Requests the ``title``, ``description``, ``tags`` and ``file_name``
        output fields. Returns ``[]`` for an empty query or on any error
        (the error is logged).
        """
        if not description:
            return []
        try:
            text_vector = self.encode_text(description)
            logger.info(f"Searching Milvus - query: {description}, limit: {limit}")
            results = self.client.search(
                collection_name=self.collection_name,
                data=[text_vector],
                limit=limit,
                search_params={
                    "metric_type": "COSINE",
                },
                output_fields=["title", "description", "tags", "file_name"],
            )
            logger.info(f"Search Result: {results}")
            # Exactly one query vector was sent, so only the first hit list
            # is relevant.
            return results[0] if results else []
        except Exception as e:
            logger.error(f"Search failed: {str(e)}")
            return []

    def get_all_stickers(self, limit: int = 1000) -> List[Dict[str, Any]]:
        """Return up to ``limit`` stored stickers without filtering.

        Returns ``[]`` on any error (the error is logged).
        """
        try:
            results = self.client.query(
                collection_name=self.collection_name,
                filter="",
                limit=limit,
                output_fields=["title", "description", "tags", "file_name", "image_hash"],
            )
            logger.info(f"Query All Stickers - limit: {limit}, results count: {len(results)}")
            return results
        except Exception as e:
            logger.error(f"Failed to get all stickers: {str(e)}")
            return []

    def check_image_exists(self, image_hash: str) -> bool:
        """Return True if a sticker with this ``image_hash`` is already stored.

        An empty or missing hash never matches and returns False without
        querying. Query errors are logged and reported as False ("not found").
        """
        if not image_hash:
            return False
        try:
            # Escape the value so a quote in the hash cannot break out of the
            # string literal in the filter expression.
            safe_hash = self._escape_filter_literal(image_hash)
            results = self.client.query(
                collection_name=self.collection_name,
                filter=f"image_hash == '{safe_hash}'",
                limit=1,
                output_fields=["file_name", "image_hash"],
            )
            return len(results) > 0
        except Exception as e:
            logger.error(f"Failed to check file exists: {str(e)}")
            return False

    def delete_sticker(self, sticker_id: int) -> str:
        """Delete the sticker with primary key ``sticker_id``.

        Returns:
            A human-readable status message. Errors are logged and described
            in the message rather than raised.
        """
        try:
            logger.info(f"Deleting sticker - id: {sticker_id}")
            res = self.client.delete(
                collection_name=self.collection_name,
                ids=[sticker_id],
            )
            logger.info(f"Deleted sticker - id: {sticker_id}, result: {res}")
            return f"Sticker with ID {sticker_id} deleted successfully"
        except Exception as e:
            logger.error(f"Failed to delete sticker: {str(e)}")
            return f"Failed to delete sticker: {str(e)}"

    def batch_store_stickers(self, stickers: List[Dict[str, Any]], batch_size: int = 100) -> bool:
        """Store many stickers in Milvus, inserting ``batch_size`` records per call.

        Args:
            stickers: Sticker dicts with the keys ``title`` (str),
                ``description`` (str), ``tags`` (list or comma-separated str),
                ``file_path`` (str) and optionally ``image_hash`` (str).
                Missing keys default to empty values.
            batch_size: Number of stickers embedded and inserted per batch.
                Defaults to 100.

        Returns:
            True if every batch was inserted (an empty or None input counts as
            success). False on the first error, which is logged. Batches
            inserted before the failure are not rolled back.
        """
        if not stickers:
            logger.warning("No stickers to store")
            return True
        try:
            total_stickers = len(stickers)
            logger.info(f"Starting batch store of {total_stickers} stickers")
            for start in range(0, total_stickers, batch_size):
                batch = stickers[start:start + batch_size]
                descriptions = [sticker.get("description", "") for sticker in batch]
                # One model call per batch instead of one per sticker.
                vectors = self.encode_texts(descriptions)
                batch_data = [
                    self._build_record(
                        vector,
                        sticker.get("title", ""),
                        desc,
                        self._normalize_tags(sticker.get("tags", [])),
                        sticker.get("file_path", ""),
                        sticker.get("image_hash"),
                    )
                    for sticker, desc, vector in zip(batch, descriptions, vectors)
                ]
                self.client.insert(
                    collection_name=self.collection_name,
                    data=batch_data,
                )
                logger.info(f"Batch {start // batch_size + 1} stored successfully - {len(batch_data)} stickers")
            logger.info("All stickers stored successfully ✅")
            return True
        except Exception as e:
            logger.error(f"Failed to batch store stickers: {str(e)}")
            return False
# Shared module-level Database instance. Constructing it creates the Milvus
# client and loads the embedding model, so this happens once at import time.
# The "stickers" collection is NOT created here; init_collection() must be
# called separately (presumably at application startup -- verify).
db: Database = Database()