import logging
import time
from abc import ABC, abstractmethod
from typing import List, Dict, Any

import gradio as gr
from pymilvus import MilvusClient
from sentence_transformers import SentenceTransformer

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s %(levelname)s %(message)s'
)
logger = logging.getLogger(__name__)

# Embedding models to compare side by side
models = [
    'shibing624/text2vec-base-chinese',
    'BAAI/bge-small-zh',
    'BAAI/bge-base-zh',
    'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2',
    'all-MiniLM-L6-v2',
    'all-MiniLM-L12-v2',
    'multi-qa-mpnet-base-dot-v1',
    # 'bge-small-en-v1.5',  # incompatible
    'all-mpnet-base-v2',
    'jinaai/jina-embeddings-v3',
]

# model_name -> StickerSearcher, populated at startup and on refresh
searchers = {}


class BaseEmbeddingModel(ABC):
    """Minimal interface every embedding backend must implement."""

    @abstractmethod
    def encode(self, text: str) -> List[float]:
        pass

    @property
    @abstractmethod
    def dimension(self) -> int:
        pass

    @property
    @abstractmethod
    def model_name(self) -> str:
        pass


class SentenceTransformerModel(BaseEmbeddingModel):
    """Embedding backend backed by a sentence-transformers model."""

    def __init__(self, model_name: str):
        self.model = SentenceTransformer(model_name, trust_remote_code=True)
        self._model_name = model_name

    def encode(self, text: str) -> List[float]:
        return self.model.encode(text).tolist()

    @property
    def dimension(self) -> int:
        return self.model.get_sentence_embedding_dimension()

    @property
    def model_name(self) -> str:
        return self._model_name


class StickerSearcher:
    """Stores and searches sticker vectors in a per-model Milvus collection."""

    def __init__(self, model: BaseEmbeddingModel):
        self.model = model
        self.client = MilvusClient(uri='./sticker.db')
        self.collection_name = f'test_{model.model_name.replace("/", "_").replace("-", "_")}'

    def init_collection(self) -> bool:
        try:
            self.client.drop_collection(collection_name=self.collection_name)
            # Quick-setup create_collection builds a default AUTOINDEX on the
            # vector field and loads the collection automatically.
            self.client.create_collection(
                collection_name=self.collection_name,
                dimension=self.model.dimension,
                metric_type='COSINE',
                primary_field_name='id',
                auto_id=True
            )
            self.client.load_collection(self.collection_name)
            logger.info(f'Collection initialized: {self.collection_name}')
            return True
        except Exception as e:
            logger.error(f'Collection init failed: {str(e)}')
            return False

    def store_vector(self, title: str, description: str, tags: List[str], file_name: str):
        vector = self.model.encode(description)
        data = [{
            'vector': vector,
            'title': title,
            'description': description,
            'tags': tags,
            'file_name': file_name
        }]
        self.client.insert(self.collection_name, data)

    def search(self, query: str, limit: int = 5) -> List[Dict[str, Any]]:
        start_time = time.time()
        query_vector = self.model.encode(query)
        encode_time = time.time() - start_time

        start_search_time = time.time()
        results = self.client.search(
            collection_name=self.collection_name,
            data=[query_vector],
            limit=limit,
            output_fields=['title', 'description', 'tags', 'file_name']
        )
        search_time = time.time() - start_search_time
        total_time = encode_time + search_time
        logger.info(f'模型 {self.model.model_name} Encoding耗时: {encode_time:.4f} 秒,搜索耗时: {search_time:.4f} 秒, 总耗时: {total_time:.4f} 秒')
        return results[0]


def create_gradio_ui():
    async def search_model(model_name: str, query: str):
        try:
            if model_name in searchers:
                return searchers[model_name].search(query)
            logger.error(f'Model not loaded: {model_name}')
            return []
        except Exception as e:
            logger.error(f'Search failed: {model_name} | Error: {str(e)}')
            return []

    async def search_all_models(query):
        if not query:
            return []
        print(f'>>>> Searching From Models {query}')
        results = []
        for model_name in models:
            result = await search_model(model_name, query)
            results.append(result)

        # One row per rank, one column per model, each cell a markdown image
        formatted_results = []
        max_results = max(len(r) for r in results)
        for i in range(max_results):
            row = [i + 1]
            for model_results in results:
                if i < len(model_results):
                    result = model_results[i]
                    image_url = f'https://huggingface.co/datasets/Nekoko/StickerSet/resolve/main/{result["entity"]["file_name"]}'
                    row.append(f'![Sticker]({image_url})\n相似度: {result["distance"]:.4f}')
                else:
                    row.append('-')
            formatted_results.append(row)
        return formatted_results

    def init_collections():
        try:
            client = MilvusClient(uri='./sticker.db')
            stickers = client.query(
                collection_name='stickers',
                filter='',
                limit=1000,
                output_fields=['title', 'description', 'tags', 'file_name']
            )
            logger.info(f'Stickers loaded: {len(stickers)}')

            def init_model(model_name):
                try:
                    searcher = StickerSearcher(SentenceTransformerModel(model_name))
                    if searcher.init_collection():
                        searchers[model_name] = searcher
                        for sticker in stickers:
                            searcher.store_vector(
                                sticker.get('title'),
                                sticker.get('description'),
                                sticker.get('tags'),
                                sticker.get('file_name')
                            )
                        logger.info(f'Model initialized: {model_name}')
                except Exception as e:
                    logger.error(f'Model init failed: {model_name} | Error: {str(e)}')

            for model_name in models:
                print(f'>>>> 初始化模型 {model_name}')
                start_time = time.time()
                init_model(model_name)
                print(f'>>>> 初始化模型 {model_name} 完成 ✅,耗时 {time.time() - start_time:.4f} 秒')
            print('>>>> 初始化所有模型完成 ✅')
            return '初始化成功!'
        except Exception as e:
            logger.error(f'Data init failed: {str(e)}')
            return f'初始化失败: {str(e)}'

    with gr.Blocks(
        title='Neko Sticker Search 🔍',
        css='.gradio-container img { width: 200px !important; height: 200px !important; object-fit: contain; }'
    ) as demo:
        with gr.Row():
            search_input = gr.Textbox(label='搜索关键词')
            search_button = gr.Button('搜索')
        headers = ['序号'] + [f'🧊{model.split("/")[-1]}' for model in models]
        results_table = gr.Dataframe(
            headers=headers,
            datatype=['number'] + ['markdown'] * len(models),
            row_count=5,
            col_count=len(models) + 1
        )
        status_box = gr.Textbox(label='状态', interactive=False)
        refresh_button = gr.Button('刷新数据')
        refresh_button.click(fn=init_collections, outputs=status_box)
        # Since this is just a simple search operation, it can be wired up directly
        search_button.click(
            fn=search_all_models,
            inputs=[search_input],
            outputs=results_table
        )
    return demo


if __name__ == '__main__':
    # Preload all models up front
    total_start_time = time.time()
    for model_name in models:
        try:
            start_time = time.time()
            searchers[model_name] = StickerSearcher(SentenceTransformerModel(model_name))
            print(f'>>>> 预加载模型 {model_name} 完成 ✅, 耗时 {time.time() - start_time:.4f} 秒')
        except Exception as e:
            logger.error(f'Model preload failed: {model_name} | Error: {str(e)}')
    logger.info(f'>>>> 预加载模型完成 ✅: {models}, 耗时 {time.time() - total_start_time:.4f} 秒')

    demo = create_gradio_ui()
    demo.launch()