# Source: Toursim-Test, src/api/routes.py
# (uploaded by zhuhai111 in "Upload 43 files", commit 7cc8bc0)
import logging
from typing import Any, Dict, Optional

from fastapi import APIRouter, HTTPException, Request
from pydantic import BaseModel

from src.core.document_processor import DocumentProcessor
from src.core.ranking import RankingSystem
from src.core.plan_generator import PlanGenerator
from src.core.embeddings import EmbeddingModel
from src.core.reranker import Reranker
from src.api.search_api import GoogleSearch, BochaSearch
from src.utils.helpers import setup_proxy
from src.api.llm_api import DeepseekInterface
# Logger named after this module; use it instead of the root `logging` functions.
logger = logging.getLogger(__name__)

# Router that the endpoints below are registered on. Its tag groups them
# under "travel" in the Swagger UI.
_ROUTER_TAGS = ["travel"]
router = APIRouter(tags=_ROUTER_TAGS)
# Root route: a static landing payload.
@router.get("/")
async def root():
    # Returns a service banner, status, version and a map of the main endpoints.
    # Key order is kept identical because it fixes the order of the JSON output.
    # No docstring is added: FastAPI would publish it as the OpenAPI description.
    endpoint_index = {
        "健康检查": "/health",
        "旅游推荐": "/api/v1/recommend",
        "API文档": "/docs",
    }
    return dict(
        message="欢迎使用旅游推荐系统 API",
        status="运行正常",
        version="1.0.0",
        endpoints=endpoint_index,
    )
# Request model for POST /api/v1/recommend.
# (No class docstring on purpose: pydantic would copy it into the JSON schema.)
class TravelQuery(BaseModel):
    # Free-text travel question. The recommend endpoint in this file searches,
    # ranks and plans against this string.
    query: str
    # Optional location hint. It must be annotated Optional[...]: with a bare
    # `str` annotation and a None default, pydantic v2 rejects an explicit JSON
    # null. (Pydantic v1 already treated it as optional, so v1 behavior is unchanged.)
    location: Optional[str] = None
    # In this file it is only echoed back in the response metadata.
    max_results: int = 10
# Response model for POST /api/v1/recommend.
# (Documented with comments only: a class docstring would change the JSON schema.)
class TravelResponse(BaseModel):
    # Recommendation payloads. The recommend endpoint in this file fills it with
    # a one-element list holding the generated plan.
    recommendations: list
    # Echo of the user's query string.
    query: str
    # Extra info; the recommend endpoint in this file sets "location",
    # "max_results" and "sources".
    metadata: Dict[str, Any]
def init_app(app):
    """Initialise shared services and store them on ``app.state``.

    Reads ``app.state.config`` (a mapping, presumably loaded by the caller -
    confirm) and sets:

    * ``app.state.proxies`` - the proxy settings returned by ``setup_proxy``;
    * ``app.state.search`` - a ``GoogleSearch`` client when the proxy is
      available, otherwise a ``BochaSearch`` client;
    * ``app.state.llm`` - a ``DeepseekInterface`` using the first model listed
      under ``llm_settings.deepseek.models``.

    Raises:
        KeyError: if a required configuration key is missing.
    """
    config = app.state.config

    # Configure the proxy and find out whether it is usable.
    proxies, proxy_available = setup_proxy(config)
    app.state.proxies = proxies

    # Google search is used only when the proxy is available; otherwise fall
    # back to the Bocha search engine.
    if proxy_available:
        app.state.search = GoogleSearch(
            api_key=config['google_api_key'],
            cx=config['google_cx'],
            proxies=proxies,
        )
        # Use the module logger (not the root `logging` functions) so records
        # carry this module's name and respect its logging configuration.
        logger.info("使用 Google 搜索引擎")
    else:
        app.state.search = BochaSearch(
            api_key=config['bocha_api_key'],
            base_url=config['bocha_base_url'],
        )
        logger.info("使用博查搜索引擎")

    # Initialise the Deepseek LLM client with the first configured model.
    deepseek_settings = config['llm_settings']['deepseek']
    app.state.llm = DeepseekInterface(
        api_key=config['deepseek_api_key'],
        base_url=deepseek_settings['base_url'],
        model=deepseek_settings['models'][0],
    )
# Model identifiers passed to EmbeddingModel / Reranker for two-stage ranking.
_EMBEDDING_MODEL_NAME = "BAAI/bge-m3"
_RERANKER_MODEL_NAME = "BAAI/bge-reranker-large"
# Passages kept after the embedding stage, and after the reranking stage.
_INITIAL_TOP_K = 10
_FINAL_TOP_K = 3


def _get_ranking_models(app):
    """Return the ``(embedding_model, reranker)`` pair, loading it once per app.

    The pair is cached on ``app.state.ranking_models`` so the models are not
    reloaded on every request. The cache is only written after both models
    load successfully, so a failed load is retried on the next request.
    NOTE(review): assumes EmbeddingModel / Reranker keep no per-request state
    and are safe to reuse across requests - confirm.
    """
    models = getattr(app.state, "ranking_models", None)
    if models is None:
        models = (
            EmbeddingModel(_EMBEDDING_MODEL_NAME),
            Reranker(_RERANKER_MODEL_NAME),
        )
        app.state.ranking_models = models
    return models


@router.post("/api/v1/recommend", response_model=TravelResponse)
async def get_travel_recommendations(query: TravelQuery, request: Request):
    """
    获取旅游推荐
    """
    # Get travel recommendations. (The docstring above is kept verbatim
    # because FastAPI publishes it as the OpenAPI description.)
    #
    # Pipeline: web search -> LLM document processing -> embedding ranking
    # (top _INITIAL_TOP_K) -> reranking (top _FINAL_TOP_K) -> LLM plan
    # generation. Any failure is logged and returned as HTTP 500.
    #
    # NOTE(review): every step is a synchronous call made directly inside an
    # `async def` handler, so it blocks the event loop for the whole request.
    # Because there is no `await`, requests are handled one at a time, which
    # also keeps the shared cached models from being used concurrently.
    # Offloading to threads would need those models confirmed thread-safe first.
    logger.info(f"收到查询请求: {query.dict()}")
    try:
        # Search client and LLM configured by init_app.
        search = request.app.state.search
        llm = request.app.state.llm

        # Run the web search.
        logger.info("开始执行搜索...")
        search_results = search.search(query.query)
        logger.info(f"搜索完成,获得 {len(search_results)} 条结果")

        # Turn the raw results into passages wrapped as {'passage': ...} dicts.
        doc_processor = DocumentProcessor(llm)
        passages = [
            {'passage': p} for p in doc_processor.process_documents(search_results)
        ]
        logger.info(f"Passages structure: {passages[:1]}")  # shape of the first element

        # Two-stage ranking with models that are loaded once and reused.
        embedding_model, reranker = _get_ranking_models(request.app)
        ranking_system = RankingSystem(embedding_model, reranker)
        initial_ranked = ranking_system.initial_ranking(
            query.query, passages, _INITIAL_TOP_K
        )
        final_ranked = ranking_system.rerank(
            query.query, initial_ranked, _FINAL_TOP_K
        )

        # Generate the travel plan from the top passages.
        plan_generator = PlanGenerator(llm)
        final_plan = plan_generator.generate_plan(query.query, final_ranked)

        return TravelResponse(
            recommendations=[final_plan['plan']],
            query=query.query,
            metadata={
                "location": query.location,
                "max_results": query.max_results,
                "sources": final_plan['sources'],
            },
        )
    except Exception as e:
        # Request boundary: log the full traceback and surface a 500 error.
        logger.error(f"处理请求时发生错误: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e)) from e
# Health-check endpoint.
@router.get("/health")
async def health_check():
    """
    API 健康检查端点
    """
    # Static liveness response: nothing is probed, it only shows the process
    # is serving requests. (The docstring is kept verbatim because FastAPI
    # publishes it as the OpenAPI description.)
    healthy = {"status": "healthy"}
    return healthy