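# NOTE(review): "Spaces: Running" appears to be page-header residue from where this file was copied; it is not Python code.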
import os | |
import logging | |
import random | |
import zipfile | |
import tempfile | |
import shutil | |
import json | |
from typing import List, Dict, Any, Optional, Union | |
from PIL import Image | |
from app.api import get_chat_completion | |
from app.config import ( | |
STICKER_RERANKING_SYSTEM_PROMPT, | |
PUBLIC_URL, | |
TEMP_DIR | |
) | |
from app.database import db | |
from app.image_utils import ( | |
save_image_temp, | |
generate_temp_image, | |
upload_folder_to_huggingface, | |
upload_to_huggingface, | |
get_image_cdn_url, | |
get_image_description, | |
calculate_image_hash | |
) | |
from app.gradio_formatter import gradio_formatter | |
from multiprocessing import Queue | |
# Configure logging at import time, then obtain this module's logger.
logging.basicConfig(
    format='%(asctime)s %(levelname)s %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
class StickerService:
    """Sticker service: business logic for uploading, importing, searching,
    listing, deleting and re-ranking stickers.

    None of the methods use instance state, so all of them are declared as
    static methods; they can be called on the class or on an instance.
    """

    @staticmethod
    def upload_sticker(image_file_path: str, title: str, description: str, tags: str) -> str:
        """Upload one sticker image and store its record.

        Args:
            image_file_path: Path of the local image file.
            title: Sticker title.
            description: Sticker description; when empty, one is obtained
                from ``get_image_description``.
            tags: Sticker tags.

        Returns:
            ``"Upload successful! <filename>"`` on success, otherwise
            ``"Upload failed: <error>"`` (a duplicate image reports
            ``File_Exists``). Errors are logged, never raised.
        """
        try:
            # Hash the image to detect duplicates; the context manager
            # releases the file handle as soon as the hash is computed.
            with Image.open(image_file_path) as image:
                image_hash = calculate_image_hash(image)

            if db.check_image_exists(image_hash):
                logger.info("文件已存在 %s", image_hash)
                raise Exception('File_Exists')

            # Upload to HuggingFace
            file_path, image_filename = upload_to_huggingface(image_file_path)

            # No description supplied: derive one from the image itself.
            if not description:
                if PUBLIC_URL:
                    image_cdn_url = f'{PUBLIC_URL}/gradio_api/file={image_file_path}'
                else:
                    image_cdn_url = get_image_cdn_url(file_path)
                logger.info("image_cdn_url %s", image_cdn_url)
                description = get_image_description(image_cdn_url)

            # Store the record in the database (Milvus)
            db.store_sticker(title, description, tags, file_path, image_hash)
            return f"Upload successful! {image_filename}"
        except Exception as e:
            logger.error(f"Upload failed: {str(e)}")
            return f"Upload failed: {str(e)}"

    @staticmethod
    def import_stickers(
        sticker_dataset: str,
        upload: bool = False,
        save_to_milvus: bool = False,
        progress_callback: callable = None,
    ) -> List[str]:
        """Import a sticker dataset from a zip archive.

        The archive is extracted under ``TEMP_DIR``; an optional ``data.json``
        (a list of objects with ``filename`` and ``content`` keys) supplies
        the description of each image. Images that already exist in the
        database, or that have no description, are skipped.

        Args:
            sticker_dataset: Path of the dataset zip archive.
            upload: Whether to upload the imported images to HuggingFace.
                Defaults to False.
            save_to_milvus: Whether to store each imported sticker in the
                database. Defaults to False.
            progress_callback: Optional callable invoked as
                ``progress_callback(file_name, status)`` for every processed
                image. Defaults to None.

        Returns:
            Human-readable result messages, one per processed image plus
            overall status lines.
        """
        results = []
        descriptions = {}
        # Pre-bind the folders so the ``finally`` clause never hits an
        # unbound name if path setup itself fails.
        cache_folder = None
        img_folder = None
        try:
            # Working directories
            cache_folder = os.path.join(TEMP_DIR, 'cache/')
            img_folder = os.path.join(TEMP_DIR, 'images/')
            data_json_path = os.path.join(cache_folder, 'data.json')
            stickers = []
            # File names already handed out in this run; guards against two
            # images drawing the same random name and overwriting each other.
            used_filenames = set()
            logger.info("start import dataset")

            # Extract the dataset
            with zipfile.ZipFile(sticker_dataset, 'r') as zip_ref:
                zip_ref.extractall(cache_folder)
            logger.info(f"Extracted dataset to: {cache_folder}")

            # Load descriptions from data.json when present
            if os.path.exists(data_json_path):
                with open(data_json_path, 'r', encoding='utf-8') as f:
                    data = json.load(f)
                descriptions = {x["filename"]: x["content"] for x in data}
                logger.info("Loaded descriptions from data.json")

            # Walk the extracted tree and process every image file
            for root, _dirs, files in os.walk(cache_folder):
                for file in files:
                    if not file.lower().endswith(('.png', '.jpg', '.jpeg', '.gif', '.webp')):
                        continue
                    image_path = os.path.join(root, file)
                    try:
                        # Keep the image open only while it is being used so
                        # the handle is released before cleanup runs.
                        with Image.open(image_path) as image:
                            image_hash = calculate_image_hash(image)
                            if db.check_image_exists(image_hash):
                                results.append(f"跳过已存在的图片: {file}")
                                if progress_callback:
                                    progress_callback(file, "Skipped (exists)")
                                continue

                            description = descriptions.get(file)
                            if not description:
                                results.append(f"跳过无描述的图片: {file}")
                                if progress_callback:
                                    progress_callback(file, "Skipped (no description)")
                                continue

                            image_filename = f"image_{random.randint(100000, 999999)}.png"
                            while image_filename in used_filenames:
                                image_filename = f"image_{random.randint(100000, 999999)}.png"
                            used_filenames.add(image_filename)
                            file_path = f"images/{image_filename}"
                            generate_temp_image(img_folder, image, image_filename)

                        if save_to_milvus:
                            db.store_sticker("", description, "", file_path, image_hash)
                        stickers.append({
                            "title": "",
                            "description": description,
                            "tags": "",
                            "file_path": file_path,
                            "image_hash": image_hash,
                        })
                        # Record the success whether or not a callback was
                        # supplied, like every other outcome above.
                        results.append(f"成功导入: {image_filename}")
                        if progress_callback:
                            progress_callback(file, "Imported")
                    except Exception as e:
                        logger.error(f"Failed to process image {file}: {str(e)}")
                        results.append(f"处理失败 {file}: {str(e)}")
                        if progress_callback:
                            progress_callback(file, f"Failed: {str(e)}")

            # Upload to HuggingFace
            if upload and stickers:
                logger.info(f"upload to huggingface, {len(stickers)} stickers")
                upload_folder_to_huggingface(img_folder)
                results.append("上传到 HuggingFace 成功")
            return results
        except Exception as e:
            logger.error(f"Import failed: {str(e)}")
            results.append(f"导入失败: {str(e)}")
            return results
        finally:
            # Remove the temporary directories
            for folder in (cache_folder, img_folder):
                if folder and os.path.exists(folder):
                    shutil.rmtree(folder)
                    logger.info(f"Cleaned up temporary directory: {folder}")

    @staticmethod
    def search_stickers(description: str, limit: int = 2, reranking : bool = False) -> List[Dict[str, Any]]:
        """Search stickers by description.

        Args:
            description: Query text; an empty query returns ``[]``.
            limit: Maximum number of results requested from the database.
            reranking: When True, the hits are re-ordered by
                ``rerank_search_results``.

        Returns:
            The list of hits, or ``[]`` when the query is empty or the
            search fails (the failure is logged).
        """
        if not description:
            return []
        try:
            results = db.search_stickers(description, limit)
            if reranking:
                # Re-rank the raw search hits with the LLM
                results = StickerService.rerank_search_results(description, results, limit)
            return results
        except Exception as e:
            logger.error(f"Search failed: {str(e)}")
            return []

    @staticmethod
    def get_all_stickers(limit: int = 1000) -> List[List]:
        """Return all stickers formatted by ``gradio_formatter``.

        Args:
            limit: Maximum number of stickers to fetch.

        Returns:
            The formatted rows, or ``[]`` on failure (the failure is logged).
        """
        try:
            results = db.get_all_stickers(limit)
            return gradio_formatter.format_all_stickers(results)
        except Exception as e:
            logger.error(f"Failed to get all stickers: {str(e)}")
            return []

    @staticmethod
    def delete_sticker(sticker_id: str) -> str:
        """Delete a sticker by ID.

        Args:
            sticker_id: ID of the sticker to delete.

        Returns:
            A success message, or ``"Delete failed: <error>"`` when the
            database call raises (the failure is logged).
        """
        try:
            # No existence check is made here; the delete is issued directly.
            db.delete_sticker(sticker_id)
            return f"Sticker with ID {sticker_id} deleted successfully"
        except Exception as e:
            logger.error(f"Delete failed: {str(e)}")
            return f"Delete failed: {str(e)}"

    @staticmethod
    def rerank_search_results(query: str, sticker_list: List[Dict[str, Any]], limit: int = 5) -> List[Dict[str, Any]]:
        """Re-order search hits using an LLM relevance judgement.

        The LLM is expected to answer with a JSON list of objects carrying
        ``sticker_id``, ``score`` and ``reason``. Matching hits get ``score``
        and ``reason`` written into their ``entity`` dict (the input hits are
        mutated) and are returned in descending score order.

        Args:
            query: The search keyword.
            sticker_list: Hits with an ``id`` and an ``entity`` dict that
                contains ``description``.
            limit: Maximum number of hits to return.

        Returns:
            At most ``limit`` re-ranked hits; ``[]`` when there is nothing to
            rank or when re-ranking fails (the failure is logged).
        """
        # Nothing to rank: skip the LLM round trip.
        if not sticker_list:
            return []
        try:
            system_prompt = STICKER_RERANKING_SYSTEM_PROMPT

            # User prompt: the query plus the id/description of each hit
            candidates = [
                {"id": hit["id"], "description": hit["entity"]["description"]}
                for hit in sticker_list
            ]
            user_prompt = f"请分析关键词 '{query}' 与以下表情包的相关性:\n{candidates}"
            logger.info(">>> 使用 LLM 模型重新排序.... %s %s", user_prompt, system_prompt)

            # Ask the LLM and parse its JSON answer
            response = get_chat_completion(user_prompt, system_prompt)
            reranked_stickers = json.loads(response)
            if not isinstance(reranked_stickers, list):
                raise ValueError("Invalid response format")

            # Highest score first
            reranked_stickers.sort(key=lambda x: float(x.get("score", 0)), reverse=True)
            logger.info(">>> LLM 排序结果 %s", reranked_stickers)

            # Index hits by stringified id; the first occurrence wins.
            hits_by_id = {}
            for hit in sticker_list:
                hits_by_id.setdefault(str(hit["id"]), hit)

            rerank_results = []
            for sticker in reranked_stickers:
                # An entry without an id, or with an unknown id, is skipped
                # instead of discarding the whole ranking.
                hit = hits_by_id.get(str(sticker.get("sticker_id")))
                if hit is None:
                    continue
                hit["entity"]["score"] = sticker.get("score", 0)
                hit["entity"]["reason"] = sticker.get("reason", "")
                rerank_results.append(hit)

            rerank_results = rerank_results[:limit]
            logger.info(">>> rerank_results %s", rerank_results)
            return rerank_results
        except Exception as e:
            logger.error(f"Reranking failed: {str(e)}")
            return []
# Create the module-level service instance
sticker_service = StickerService()