File size: 4,443 Bytes
7cc8bc0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
import logging
from typing import Any, Dict, Optional

from fastapi import APIRouter, HTTPException, Request
from pydantic import BaseModel

from src.core.document_processor import DocumentProcessor
from src.core.ranking import RankingSystem
from src.core.plan_generator import PlanGenerator
from src.core.embeddings import EmbeddingModel
from src.core.reranker import Reranker
from src.api.search_api import GoogleSearch, BochaSearch
from src.utils.helpers import setup_proxy
from src.api.llm_api import DeepseekInterface

logger = logging.getLogger(__name__)

# Router for the endpoints declared in this module; "travel" is the tag
# label shown in Swagger UI.
router = APIRouter(tags=["travel"])

# Root route
@router.get("/")
async def root():
    """Return a static welcome payload: message, status, version and endpoint map."""
    # Display name -> path for the endpoints advertised to the caller.
    endpoint_map = {
        "健康检查": "/health",
        "旅游推荐": "/api/v1/recommend",
        "API文档": "/docs",
    }
    payload = {
        "message": "欢迎使用旅游推荐系统 API",
        "status": "运行正常",
        "version": "1.0.0",
        "endpoints": endpoint_map,
    }
    return payload

# Request model
class TravelQuery(BaseModel):
    """Request body for the travel recommendation endpoint.

    Attributes:
        query: Free-text travel question; passed to the search engine and
            to the ranking and plan-generation steps.
        location: Optional location; echoed back in the response metadata.
        max_results: Requested number of results (default 10); echoed back
            in the response metadata.
    """

    query: str
    # Declared Optional so that an explicit JSON null validates as well as
    # an omitted field (a bare `str = None` is not an optional type).
    location: Optional[str] = None
    max_results: int = 10

# Response model
class TravelResponse(BaseModel):
    # Generated plan(s); the recommend endpoint returns a single-element list.
    recommendations: list
    # The query text, echoed back to the caller.
    query: str
    # Extra details about the response; the recommend endpoint fills in
    # "location", "max_results" and "sources".
    metadata: Dict[str, Any]

def init_app(app):
    """Attach the proxy settings, search client and LLM client to ``app.state``.

    Reads ``app.state.config`` (which must already be set) and stores:

    * ``app.state.proxies`` - the proxies returned by ``setup_proxy``;
    * ``app.state.search`` - a ``GoogleSearch`` when ``setup_proxy`` reports
      the proxy as available, otherwise a ``BochaSearch``;
    * ``app.state.llm`` - a ``DeepseekInterface`` using the first model
      listed under ``llm_settings.deepseek.models``.

    Args:
        app: Application object exposing ``state.config`` as a mapping.

    Raises:
        KeyError: Presumably, when a required configuration key is missing
            (assumes ``config`` is a plain dict - TODO confirm).
    """
    config = app.state.config

    # Set up the proxy and learn whether it is usable.
    proxies, proxy_available = setup_proxy(config)
    app.state.proxies = proxies

    # Pick the search engine according to proxy availability.
    if proxy_available:
        app.state.search = GoogleSearch(
            api_key=config['google_api_key'],
            cx=config['google_cx'],
            proxies=proxies,
        )
        # Log through the module logger, like the rest of this file.
        logger.info("使用 Google 搜索引擎")
    else:
        app.state.search = BochaSearch(
            api_key=config['bocha_api_key'],
            base_url=config['bocha_base_url'],
        )
        logger.info("使用博查搜索引擎")

    # Initialise the Deepseek LLM client.
    app.state.llm = DeepseekInterface(
        api_key=config['deepseek_api_key'],
        base_url=config['llm_settings']['deepseek']['base_url'],
        model=config['llm_settings']['deepseek']['models'][0],
    )

    # ... other initialization code ...

# Model identifiers and cut-offs for the two-stage ranking pipeline.
_EMBEDDING_MODEL_NAME = "BAAI/bge-m3"
_RERANKER_MODEL_NAME = "BAAI/bge-reranker-large"
_INITIAL_TOP_K = 10
_FINAL_TOP_K = 3


def _get_ranking_system(app_state):
    """Return the ranking system cached on ``app_state``, building it on first use.

    Constructing the embedding model and the reranker on every request is
    wasteful, so the assembled ``RankingSystem`` is stored on the application
    state and reused.
    NOTE(review): assumes ``RankingSystem`` keeps no per-request state
    (its methods receive the query and passages as arguments) - confirm.
    """
    ranking_system = getattr(app_state, "ranking_system", None)
    if ranking_system is None:
        ranking_system = RankingSystem(
            EmbeddingModel(_EMBEDDING_MODEL_NAME),
            Reranker(_RERANKER_MODEL_NAME),
        )
        app_state.ranking_system = ranking_system
    return ranking_system


@router.post("/api/v1/recommend", response_model=TravelResponse)
async def get_travel_recommendations(query: TravelQuery, request: Request):
    """Produce a travel plan for the submitted query.

    Pipeline: web search -> document processing -> initial ranking ->
    reranking -> plan generation with the LLM.

    Args:
        query: Validated request body.
        request: Incoming request; ``request.app.state`` must provide the
            ``search`` and ``llm`` clients.

    Returns:
        TravelResponse: The generated plan as a single-element
        ``recommendations`` list, the echoed query, and metadata holding
        ``location``, ``max_results`` and the plan's ``sources``.

    Raises:
        HTTPException: Status 500 carrying the error text when any pipeline
            step fails; an ``HTTPException`` raised by a step is propagated
            unchanged.
    """
    logger.info("收到查询请求: %s", query.dict())
    try:
        app_state = request.app.state
        # Clients stored on the application state (search is proxy-aware).
        search = app_state.search
        llm = app_state.llm

        # NOTE(review): the calls below are made without ``await``; if they
        # block, they stall the event loop of this async handler - consider
        # offloading them to a worker thread.
        logger.info("开始执行搜索...")
        search_results = search.search(query.query)
        logger.info("搜索完成,获得 %s 条结果", len(search_results))

        # Process the documents and wrap each passage in a dict keyed by
        # 'passage' before handing them to the ranking system.
        doc_processor = DocumentProcessor(llm)
        passages = [
            {'passage': p}
            for p in doc_processor.process_documents(search_results)
        ]
        # Log the first element to show the passage structure.
        logger.info("Passages structure: %s", passages[:1])

        # Two-stage ranking with the cached models.
        ranking_system = _get_ranking_system(app_state)
        initial_ranked = ranking_system.initial_ranking(
            query.query,
            passages,
            _INITIAL_TOP_K,
        )
        final_ranked = ranking_system.rerank(
            query.query,
            initial_ranked,
            _FINAL_TOP_K,
        )

        # Generate the plan from the top-ranked passages.
        plan_generator = PlanGenerator(llm)
        final_plan = plan_generator.generate_plan(query.query, final_ranked)

        return TravelResponse(
            recommendations=[final_plan['plan']],
            query=query.query,
            metadata={
                "location": query.location,
                "max_results": query.max_results,
                "sources": final_plan['sources'],
            },
        )

    except HTTPException:
        # Deliberate HTTP errors keep their own status code and detail.
        raise
    except Exception as e:
        logger.exception("处理请求时发生错误: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e

# Health-check endpoint
@router.get("/health")
async def health_check():
    """Return a constant payload reporting the API as healthy."""
    return dict(status="healthy")