Spaces:
Running
Running
import logging
from functools import lru_cache
from typing import Any, Dict, Optional

from fastapi import APIRouter, HTTPException, Request
from pydantic import BaseModel

from src.core.document_processor import DocumentProcessor
from src.core.ranking import RankingSystem
from src.core.plan_generator import PlanGenerator
from src.core.embeddings import EmbeddingModel
from src.core.reranker import Reranker
from src.api.search_api import GoogleSearch, BochaSearch
from src.utils.helpers import setup_proxy
from src.api.llm_api import DeepseekInterface
logger = logging.getLogger(__name__)

# Router for this module's endpoints; "travel" is the tag shown in Swagger UI.
router = APIRouter(tags=["travel"])
# Root route
async def root():
    """Describe the service: welcome message, status, version and key endpoint paths.

    NOTE(review): no route decorator is visible on this function in the reviewed
    source; confirm it is registered (e.g. via ``router``).
    """
    endpoint_map = {
        "健康检查": "/health",
        "旅游推荐": "/api/v1/recommend",
        "API文档": "/docs",
    }
    return dict(
        message="欢迎使用旅游推荐系统 API",
        status="运行正常",
        version="1.0.0",
        endpoints=endpoint_map,
    )
# Request model
class TravelQuery(BaseModel):
    """Request body for a travel recommendation query."""

    # Free-text query used for search, ranking and plan generation.
    query: str
    # Optional location. Typed Optional so an explicit JSON null is accepted
    # (a bare ``str = None`` annotation rejects null under pydantic v2).
    location: Optional[str] = None
    # NOTE(review): in this module it is only echoed back in the response
    # metadata; it does not limit the number of results.
    max_results: int = 10
# Response model
class TravelResponse(BaseModel):
    """Response body returned by the travel recommendation handler."""

    # Generated recommendations; the handler in this module puts a single plan here.
    recommendations: list
    # The original query text, echoed back to the caller.
    query: str
    # Extra information; the handler in this module fills location, max_results and sources.
    metadata: Dict[str, Any]
def init_app(app):
    """Initialize shared services on ``app.state``.

    Reads ``app.state.config`` and stores:

    * ``app.state.proxies``: the proxy settings returned by ``setup_proxy``;
    * ``app.state.search``: a ``GoogleSearch`` client when ``setup_proxy``
      reports a usable proxy, otherwise a ``BochaSearch`` client;
    * ``app.state.llm``: a ``DeepseekInterface`` bound to the first model
      listed under ``config['llm_settings']['deepseek']['models']``.

    Raises:
        KeyError: presumably, if ``config`` is a dict missing one of the keys
            read below (the config type is not visible here; verify).
    """
    config = app.state.config

    # Configure proxies and learn whether a proxy is actually usable.
    proxies, proxy_available = setup_proxy(config)
    app.state.proxies = proxies

    # Google is only selected when a proxy is available (presumably it is not
    # reachable directly from the deployment); otherwise fall back to Bocha.
    # Log through the module logger: calling the root ``logging.info`` bypasses
    # this module's logger and implicitly runs ``basicConfig`` when the root
    # logger has no handlers.
    if proxy_available:
        app.state.search = GoogleSearch(
            api_key=config['google_api_key'],
            cx=config['google_cx'],
            proxies=proxies,
        )
        logger.info("使用 Google 搜索引擎")
    else:
        app.state.search = BochaSearch(
            api_key=config['bocha_api_key'],
            base_url=config['bocha_base_url'],
        )
        logger.info("使用博查搜索引擎")

    # Deepseek LLM client, using the first configured model.
    deepseek_api_key = config['deepseek_api_key']
    deepseek_settings = config['llm_settings']['deepseek']
    app.state.llm = DeepseekInterface(
        api_key=deepseek_api_key,
        base_url=deepseek_settings['base_url'],
        model=deepseek_settings['models'][0],
    )
    # ... further initialization goes here ...
async def get_travel_recommendations(query: TravelQuery, request: Request):
    """
    Get travel recommendations for a query.

    Pipeline: web search -> document processing -> two-stage ranking
    (initial ranking, then rerank) -> plan generation.

    Args:
        query: Parsed request body. ``query.query`` drives search, ranking and
            planning; ``location`` and ``max_results`` are only echoed in the
            response metadata.
        request: Incoming request; the search client and LLM are read from
            ``request.app.state`` (populated by ``init_app``).

    Returns:
        TravelResponse whose ``recommendations`` holds the generated plan and
        whose ``metadata["sources"]`` holds the plan's sources.

    Raises:
        HTTPException: status 500 wrapping any exception raised by the pipeline.

    NOTE(review): every pipeline step is a synchronous call made directly inside
    this ``async def``, so it blocks the event loop while it runs. Consider
    offloading it (e.g. ``run_in_threadpool``) once the components are
    confirmed thread-safe.
    """
    logger.info("收到查询请求: %s", query.dict())
    try:
        search = request.app.state.search
        llm = request.app.state.llm

        # 1. Web search with the client stored on app.state.
        logger.info("开始执行搜索...")
        search_results = search.search(query.query)
        logger.info("搜索完成,获得 %s 条结果", len(search_results))

        # 2. Process the raw results and wrap each passage as {'passage': ...}
        #    before ranking.
        doc_processor = DocumentProcessor(llm)
        passages = [
            {'passage': p} for p in doc_processor.process_documents(search_results)
        ]
        # Debug aid: structure of the first wrapped passage (module logger at
        # DEBUG, so passage content is not dumped at INFO on every request).
        logger.debug("Passages structure: %s", passages[:1])

        # 3. Two-stage ranking. The models are built once and reused across
        #    requests instead of being reloaded every time.
        ranking_system = RankingSystem(*_load_ranking_models())
        initial_ranked = ranking_system.initial_ranking(
            query.query,
            passages,
            _INITIAL_TOP_K,
        )
        final_ranked = ranking_system.rerank(
            query.query,
            initial_ranked,
            _FINAL_TOP_K,
        )

        # 4. Generate the plan; its 'plan' and 'sources' entries feed the response.
        plan_generator = PlanGenerator(llm)
        final_plan = plan_generator.generate_plan(query.query, final_ranked)

        return TravelResponse(
            recommendations=[final_plan['plan']],
            query=query.query,
            metadata={
                "location": query.location,
                "max_results": query.max_results,
                "sources": final_plan['sources'],
            },
        )
    except Exception as e:
        logger.error("处理请求时发生错误: %s", e, exc_info=True)
        raise HTTPException(status_code=500, detail=str(e)) from e


# Model identifiers and ranking cut-offs used by get_travel_recommendations.
# They are defined after the endpoint so that any decorator placed directly
# above the endpoint's ``def`` stays attached to it; they are only looked up
# at request time.
_EMBEDDING_MODEL_NAME = "BAAI/bge-m3"
_RERANKER_MODEL_NAME = "BAAI/bge-reranker-large"
_INITIAL_TOP_K = 10  # candidates kept after the first ranking stage
_FINAL_TOP_K = 3  # passages handed to the plan generator


@lru_cache(maxsize=1)
def _load_ranking_models():
    """Return ``(EmbeddingModel, Reranker)``, constructed on first call and cached.

    Repeated requests reuse the same instances instead of re-constructing them
    (construction presumably loads model weights). ``lru_cache`` does not store
    a call that raises, so a failed load is retried on the next request.
    """
    return EmbeddingModel(_EMBEDDING_MODEL_NAME), Reranker(_RERANKER_MODEL_NAME)
# Health-check endpoint
async def health_check():
    """Return a constant health payload; no dependencies are probed.

    NOTE(review): no route decorator is visible on this function in the reviewed
    source; confirm it is registered (e.g. via ``router``).
    """
    return dict(status="healthy")