import os
import logging
import json
from typing import List, Dict, Any, Optional, Union

from PIL import Image

from app.api import get_chat_completion
from app.config import (
    STICKER_RERANKING_SYSTEM_PROMPT,
    PUBLIC_URL
)
from app.database import db
from app.image_utils import (
    save_image_temp,
    upload_to_huggingface,
    get_image_cdn_url,
    get_image_description,
    calculate_image_hash
)
from app.gradio_formatter import gradio_formatter

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s %(message)s')
logger = logging.getLogger(__name__)


class StickerService:
    """Sticker service: business logic for uploading, searching, listing and deleting stickers."""

    @staticmethod
    def upload_sticker(image_file_path: str, title: str, description: str, tags: str) -> str:
        """Upload a sticker image and store its metadata.

        Steps: hash the image and reject duplicates, upload the file via
        ``upload_to_huggingface``, generate a description with
        ``get_image_description`` when none was supplied, then persist the
        record with ``db.store_sticker``.

        Args:
            image_file_path: Local path of the image to upload.
            title: Sticker title.
            description: Sticker description; if empty, one is generated
                from the image URL.
            tags: Tags string, passed through to the database unchanged.

        Returns:
            A human-readable status message. Errors are logged and returned
            as ``"Upload failed: <reason>"`` instead of being raised; a
            duplicate image yields ``"Upload failed: File_Exists"``.
        """
        try:
            # Context manager closes the underlying file handle; the image is
            # only needed for hashing.
            with Image.open(image_file_path) as image:
                image_hash = calculate_image_hash(image)

            # Reject images whose hash is already stored.
            if db.check_image_exists(image_hash):
                logger.info("文件已存在 %s", image_hash)
                raise Exception('File_Exists')

            # Upload to HuggingFace
            file_path, image_filename = upload_to_huggingface(image_file_path)

            # No description supplied: ask the image-description helper for one.
            if not description:
                if PUBLIC_URL:
                    # Serve the local file through the Gradio file endpoint.
                    image_cdn_url = f'{PUBLIC_URL}/gradio_api/file={image_file_path}'
                else:
                    image_cdn_url = get_image_cdn_url(file_path)
                logger.info("image_cdn_url %s", image_cdn_url)
                description = get_image_description(image_cdn_url)

            # Persist the record (the backing store is presumably Milvus, per
            # the original comment — confirm in app.database).
            db.store_sticker(title, description, tags, file_path, image_hash)
            return f"Upload successful! {image_filename}"
        except Exception as e:
            logger.error(f"Upload failed: {str(e)}")
            return f"Upload failed: {str(e)}"

    @staticmethod
    def search_stickers(description: str, limit: int = 2, reranking: bool = False) -> List[Dict[str, Any]]:
        """Search stickers matching ``description``.

        Args:
            description: Free-text query; an empty query returns ``[]``.
            limit: Maximum number of hits requested from the database.
            reranking: When true, reorder the hits with the LLM reranker.

        Returns:
            The list of hits returned by ``db.search_stickers`` (optionally
            reranked), or ``[]`` if the search fails.
        """
        if not description:
            return []
        try:
            results = db.search_stickers(description, limit)
            if reranking:
                # Reorder the search results by LLM-judged relevance.
                results = StickerService.rerank_search_results(description, results, limit)
            return results
        except Exception as e:
            logger.error(f"Search failed: {str(e)}")
            return []

    @staticmethod
    def get_all_stickers(limit: int = 1000) -> List[List]:
        """Return up to ``limit`` stickers formatted for the Gradio UI, or ``[]`` on failure."""
        try:
            results = db.get_all_stickers(limit)
            return gradio_formatter.format_all_stickers(results)
        except Exception as e:
            logger.error(f"Failed to get all stickers: {str(e)}")
            return []

    @staticmethod
    def delete_sticker(sticker_id: str) -> str:
        """Delete the sticker with ``sticker_id`` and return a status message.

        No existence check is done here; any error raised by
        ``db.delete_sticker`` is logged and reported in the returned message.
        """
        try:
            db.delete_sticker(sticker_id)
            return f"Sticker with ID {sticker_id} deleted successfully"
        except Exception as e:
            logger.error(f"Delete failed: {str(e)}")
            return f"Delete failed: {str(e)}"

    @staticmethod
    def _parse_llm_json(text: str) -> Any:
        """Parse JSON from an LLM reply, tolerating a surrounding Markdown code fence."""
        cleaned = text.strip()
        if cleaned.startswith("```"):
            # Drop the opening fence line (e.g. "```json") and the closing fence.
            cleaned = cleaned.split("\n", 1)[1] if "\n" in cleaned else ""
            cleaned = cleaned.rsplit("```", 1)[0]
        return json.loads(cleaned)

    @staticmethod
    def rerank_search_results(query: str, sticker_list: List[Dict[str, Any]], limit: int = 5) -> List[Dict[str, Any]]:
        """Reorder search hits by LLM-judged relevance to ``query``.

        Each hit is expected to have an ``"id"`` and an
        ``"entity"["description"]`` (the fields read below). The LLM is
        expected to answer with a JSON list of objects carrying
        ``sticker_id``, ``score`` and ``reason``.

        Args:
            query: The user's search text.
            sticker_list: Hits as returned by ``db.search_stickers``.
            limit: Maximum number of hits to return.

        Returns:
            Up to ``limit`` hits ordered by descending LLM score; each is a
            shallow copy whose ``entity`` additionally holds ``score`` and
            ``reason``. Hits the LLM did not mention are dropped. If
            reranking fails (LLM error, unparsable or malformed reply), the
            original hits truncated to ``limit`` are returned so a reranker
            failure does not wipe out the search results.
        """
        if not sticker_list:
            return []
        try:
            # Build the user prompt: the query plus each candidate's id and description.
            candidates = [
                {"id": hit["id"], "description": hit["entity"]["description"]}
                for hit in sticker_list
            ]
            user_prompt = f"请分析关键词 '{query}' 与以下表情包的相关性:\n{candidates}"
            logger.debug(">>> 使用 LLM 模型重新排序.... %s %s", user_prompt, STICKER_RERANKING_SYSTEM_PROMPT)

            response = get_chat_completion(user_prompt, STICKER_RERANKING_SYSTEM_PROMPT)
            reranked = StickerService._parse_llm_json(response)
            if not isinstance(reranked, list):
                raise ValueError("Invalid response format")

            # Ignore malformed entries, then sort by score (highest first).
            reranked = [item for item in reranked if isinstance(item, dict)]
            reranked.sort(key=lambda item: float(item.get("score", 0)), reverse=True)
            logger.debug(">>> LLM 排序结果 %s", reranked)

            # Index hits by stringified id; the first hit wins on duplicate ids.
            hits_by_id: Dict[str, Dict[str, Any]] = {}
            for hit in sticker_list:
                hits_by_id.setdefault(str(hit["id"]), hit)

            rerank_results: List[Dict[str, Any]] = []
            seen = set()
            for item in reranked:
                key = str(item.get("sticker_id"))
                hit = hits_by_id.get(key)
                # Skip ids the LLM invented and ids it listed more than once.
                if hit is None or key in seen:
                    continue
                seen.add(key)
                # Copy instead of mutating the caller's hit dicts.
                entity = dict(hit["entity"])
                entity["score"] = item.get("score", 0)
                entity["reason"] = item.get("reason", "")
                rerank_results.append({**hit, "entity": entity})

            rerank_results = rerank_results[:limit]
            logger.debug(">>> rerank_results %s", rerank_results)
            return rerank_results
        except Exception as e:
            logger.error(f"Reranking failed: {str(e)}")
            return list(sticker_list[:limit])


# Create the service instance
sticker_service = StickerService()