# Exported from a Hugging Face Spaces file view (Space status: Running) - revision 7cc8bc0, file size 4,443 bytes.
import logging
from typing import Any, Dict, Optional

from fastapi import APIRouter, HTTPException, Request
from pydantic import BaseModel

from src.core.document_processor import DocumentProcessor
from src.core.ranking import RankingSystem
from src.core.plan_generator import PlanGenerator
from src.core.embeddings import EmbeddingModel
from src.core.reranker import Reranker
from src.api.search_api import GoogleSearch, BochaSearch
from src.utils.helpers import setup_proxy
from src.api.llm_api import DeepseekInterface
logger = logging.getLogger(__name__)

# Router that the endpoints in this module register on; the "travel" tag is
# the label they are grouped under in Swagger UI.
router = APIRouter(tags=["travel"])
# Root route
@router.get("/")
async def root():
    # Returns a static landing payload: a welcome message, a status string,
    # the API version and a map of endpoint labels to their paths.
    endpoint_index = {
        "健康检查": "/health",
        "旅游推荐": "/api/v1/recommend",
        "API文档": "/docs",
    }
    landing = {
        "message": "欢迎使用旅游推荐系统 API",
        "status": "运行正常",
        "version": "1.0.0",
        "endpoints": endpoint_index,
    }
    return landing
# Request model
class TravelQuery(BaseModel):
    # Free-text travel request; this is the string handed to the search
    # engine, the rankers and the plan generator.
    query: str
    # Optional location hint, echoed back in the response metadata.
    # Declared Optional so that an explicit JSON null validates as well as an
    # omitted field (a bare `str` annotation with a None default rejects null
    # under Pydantic v2).
    location: Optional[str] = None
    # Requested number of results, echoed back in the response metadata.
    max_results: int = 10
# Response model
class TravelResponse(BaseModel):
    # Generated recommendations; the recommend endpoint fills this with a
    # single-element list holding the generated plan.
    recommendations: list
    # The query string the recommendations were produced for.
    query: str
    # Free-form extra information about the request and its sources.
    metadata: Dict[str, Any]
def init_app(app):
    """Initialise the service clients shared through ``app.state``.

    Reads ``app.state.config`` and stores three attributes on ``app.state``:
    ``proxies`` (as returned by ``setup_proxy``), ``search`` (a Google or
    Bocha search client, depending on whether the proxy is available) and
    ``llm`` (the Deepseek interface).

    Args:
        app: Application object exposing ``state.config``; the config must
            provide the keys looked up below (assumed to be a dict-like
            mapping - confirm against the caller).
    """
    config = app.state.config

    # Set up the proxy and find out whether it is usable.
    proxies, proxy_available = setup_proxy(config)
    app.state.proxies = proxies

    # Choose the search engine based on proxy availability.
    if proxy_available:
        app.state.search = GoogleSearch(
            api_key=config['google_api_key'],
            cx=config['google_cx'],
            proxies=proxies,
        )
        # Use the module logger rather than the root-level logging.info(),
        # which implicitly calls basicConfig() when no handler is installed.
        logger.info("使用 Google 搜索引擎")
    else:
        app.state.search = BochaSearch(
            api_key=config['bocha_api_key'],
            base_url=config['bocha_base_url'],
        )
        logger.info("使用博查搜索引擎")

    # Initialise the Deepseek LLM with the first configured model.
    deepseek_settings = config['llm_settings']['deepseek']
    app.state.llm = DeepseekInterface(
        api_key=config['deepseek_api_key'],
        base_url=deepseek_settings['base_url'],
        model=deepseek_settings['models'][0],
    )
    # ... other initialisation code ...
@router.post("/api/v1/recommend", response_model=TravelResponse)
async def get_travel_recommendations(query: TravelQuery, request: Request):
    """
    获取旅游推荐
    """
    # The docstring above is kept verbatim: FastAPI publishes an endpoint's
    # docstring as its description in the generated OpenAPI document.
    #
    # Pipeline: web search -> passage extraction -> two-stage ranking ->
    # LLM plan generation.
    #
    # Parameters:
    #   query   - validated request body.
    #   request - incoming request; used to reach the search client and the
    #             LLM stored on request.app.state.
    # Returns a TravelResponse whose single recommendation is the generated
    # plan. Any exception raised along the way is logged with its traceback
    # and re-raised as an HTTP 500 whose detail is str(error).
    #
    # NOTE(review): none of the calls below are awaited, so they run on the
    # event loop for their whole duration - confirm that is acceptable or
    # move them to a worker thread.
    # NOTE(review): query.max_results is only echoed back in the metadata;
    # the ranking cut-offs below are fixed - confirm this is intended.
    initial_top_k = 10
    final_top_k = 3

    logger.info("收到查询请求: %s", query.dict())
    try:
        # Reuse the clients stored on the application state.
        search = request.app.state.search
        llm = request.app.state.llm

        # Run the search.
        logger.info("开始执行搜索...")
        search_results = search.search(query.query)
        logger.info("搜索完成,获得 %d 条结果", len(search_results))

        # Turn the search results into passages, each wrapped as
        # {'passage': ...} before being handed to the ranking system.
        doc_processor = DocumentProcessor(llm)
        passages = [
            {'passage': passage}
            for passage in doc_processor.process_documents(search_results)
        ]
        # Log the structure of the first element (module logger, not the
        # root-level logging.info used previously).
        logger.info("Passages structure: %s", passages[:1])

        # Build the ranking system.
        # NOTE(review): both models are constructed on every request; if
        # construction is expensive, consider building them once at startup.
        embedding_model = EmbeddingModel("BAAI/bge-m3")
        reranker = Reranker("BAAI/bge-reranker-large")
        ranking_system = RankingSystem(embedding_model, reranker)

        # Two-stage ranking: initial ranking first, then reranking.
        initial_ranked = ranking_system.initial_ranking(
            query.query,
            passages,
            initial_top_k,
        )
        final_ranked = ranking_system.rerank(
            query.query,
            initial_ranked,
            final_top_k,
        )

        # Generate the plan from the top-ranked passages.
        plan_generator = PlanGenerator(llm)
        final_plan = plan_generator.generate_plan(query.query, final_ranked)

        return TravelResponse(
            recommendations=[final_plan['plan']],
            query=query.query,
            metadata={
                "location": query.location,
                "max_results": query.max_results,
                "sources": final_plan['sources'],
            },
        )
    except Exception as e:
        # logger.exception logs at ERROR level and attaches the traceback.
        logger.exception("处理请求时发生错误: %s", e)
        # Chain the original error so the cause survives in tracebacks.
        raise HTTPException(status_code=500, detail=str(e)) from e
# Health-check endpoint
@router.get("/health")
async def health_check():
    """
    API 健康检查端点
    """
    # The docstring is kept verbatim because FastAPI publishes it as the
    # endpoint description. The payload is a constant status marker.
    status_payload = {"status": "healthy"}
    return status_payload