Spaces:
Sleeping
Sleeping
import gradio as gr | |
from fastapi import FastAPI, UploadFile, File | |
import uvicorn | |
import logging | |
import os | |
import tempfile | |
# Root logging configuration: INFO and above, one timestamped line per record.
_LOG_FORMAT = '%(asctime)s %(levelname)s %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Module-level logger used by the API handlers in this file.
logger = logging.getLogger(__name__)
# FastAPI application instance. When this module is run as a script, the
# Gradio UI is mounted onto it and the result is rebound to ``app``.
app: FastAPI = FastAPI()
# 使用app/database.py中的数据库实例 | |
from app.database import db | |
def init_milvus() -> None:
    """Initialise the Milvus database backing the sticker store.

    Delegates entirely to ``init_collection()`` on the shared ``db`` instance
    imported from ``app.database``; nothing is returned.
    """
    db.init_collection()
# 图像处理相关功能已经在app/image_utils.py中实现 | |
# 使用app/services.py中的服务 | |
from app.services import sticker_service | |
# 使用app/services.py中的服务,不再需要重复实现这些功能 | |
# 导入UI模块 | |
from app.ui import StickerUI | |
# FastAPI route handlers.
# NOTE(review): no route decorator or app.add_api_route call is visible for
# these handlers in this file — confirm they are registered on ``app``
# somewhere, otherwise they are unreachable over HTTP.


def api_get_stickers():
    """Return every sticker known to the sticker service.

    Returns:
        The value of ``sticker_service.get_all_stickers()``, passed through
        unchanged.
    """
    sticker_list = sticker_service.get_all_stickers()
    # Route through the configured logger rather than print(); debug level
    # because the full list is emitted on every call and may be large.
    logger.debug('>>> GET Sticker_list %s', sticker_list)
    return sticker_list
async def api_search_stickers(request: dict):
    """Search stickers by free-text description.

    Args:
        request: Parsed request body; only the ``'description'`` key is read.

    Returns:
        The result of ``sticker_service.search_stickers(description=...,
        limit=1)``, or ``[]`` when no usable description is supplied (missing,
        empty, or null).
    """
    # ``or ''`` also normalises an explicit null description; the previous
    # ``len(description)`` check raised TypeError on None.
    description = request.get('description') or ''
    if not description:
        return []  # nothing to search for -> empty result, service not called
    sticker_list = sticker_service.search_stickers(
        description=description,
        limit=1,
    )
    logger.debug('>>> SEARCH Sticker_list %s', sticker_list)
    return sticker_list
async def api_delete_stickers(request: dict):
    """Delete a single sticker by ID.

    Args:
        request: Parsed request body; the sticker ID is read from ``'id'``.

    Returns:
        ``{"status": "success", "message": <service result>}`` on success,
        otherwise ``{"status": "error", "message": <reason>}``. Failures are
        reported in the returned payload rather than raised.
    """
    try:
        sticker_id = request.get('id')
        # Any falsy ID (missing, None, '', 0) is rejected before the service call.
        if not sticker_id:
            return {"status": "error", "message": "Missing sticker ID"}
        result = sticker_service.delete_sticker(sticker_id)
    except Exception as e:
        # Top-level handler boundary: log with traceback (logger.error dropped
        # it) and convert to an error payload instead of propagating.
        logger.exception("Delete failed: %s", e)
        return {"status": "error", "message": f"Delete failed: {str(e)}"}
    else:
        return {"status": "success", "message": result}
async def api_import_dataset(
    file: UploadFile = File(...),
    upload: bool = False,
    save_to_milvus: bool = False,
):
    """Import a sticker dataset from an uploaded ZIP archive.

    The upload is copied to a temporary ``.zip`` file on disk, whose path is
    handed to ``sticker_service.import_stickers``. The temporary file is
    removed afterwards, including when the copy or the import fails.

    Args:
        file: The uploaded ZIP archive.
        upload: Passed through unchanged to ``import_stickers``.
        save_to_milvus: Passed through unchanged to ``import_stickers``.

    Returns:
        ``{"status": "success", "results": <import results>}`` on success,
        otherwise ``{"status": "error", "message": <reason>}``.
    """
    chunk_size = 1024 * 1024  # copy the upload 1 MiB at a time
    temp_file_path = None
    try:
        # delete=False keeps the file on disk after the handle closes, because
        # the service receives a path, not an open handle. Removal happens in
        # ``finally`` below so a failed import no longer leaks the file.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".zip") as temp_file:
            temp_file_path = temp_file.name
            # Stream in chunks so a large archive is never held in memory whole.
            while True:
                chunk = await file.read(chunk_size)
                if not chunk:
                    break
                temp_file.write(chunk)
        # NOTE(review): import_stickers is called synchronously and will block
        # the event loop for its duration — consider offloading to a thread.
        results = sticker_service.import_stickers(temp_file_path, upload, save_to_milvus)
        return {"status": "success", "results": results}
    except Exception as e:
        logger.exception("Import failed: %s", e)
        return {"status": "error", "message": f"Import failed: {str(e)}"}
    finally:
        if temp_file_path is not None:
            try:
                os.unlink(temp_file_path)
            except OSError:
                logger.warning("Could not remove temporary file %s", temp_file_path)
# Script entry point. Milvus initialisation and the Gradio mount happen only
# here, so serving this module's ``app`` through an external ASGI runner skips
# both.
if __name__ == "__main__":
    init_milvus()
    sticker_ui = StickerUI()
    gradio_blocks = sticker_ui.create_ui()
    # Mount the Gradio UI at the site root and serve on all interfaces, port 7860.
    app = gr.mount_gradio_app(app, gradio_blocks, '/')
    uvicorn.run(app, host="0.0.0.0", port=7860)