Upload 43 files

- .gitattributes +1 -3
- .gradio/certificate.pem +31 -0
- .huggingface/YAML +9 -0
- README.md +40 -14
- app.py +799 -0
- config/config.yaml +72 -0
- requirements.txt +16 -0
- src/__init__.py +4 -0
- src/__pycache__/__init__.cpython-310.pyc +0 -0
- src/api/__init__.py +47 -0
- src/api/__pycache__/__init__.cpython-310.pyc +0 -0
- src/api/__pycache__/llm_api.cpython-310.pyc +0 -0
- src/api/__pycache__/ollama_api.cpython-310.pyc +0 -0
- src/api/__pycache__/routes.cpython-310.pyc +0 -0
- src/api/__pycache__/search_api.cpython-310.pyc +0 -0
- src/api/llm_api.py +319 -0
- src/api/ollama_api.py +101 -0
- src/api/routes.py +141 -0
- src/api/search_api.py +263 -0
- src/core/__pycache__/document_processor.cpython-310.pyc +0 -0
- src/core/__pycache__/embeddings.cpython-310.pyc +0 -0
- src/core/__pycache__/html_processor.cpython-310.pyc +0 -0
- src/core/__pycache__/plan_generator.cpython-310.pyc +0 -0
- src/core/__pycache__/ranking.cpython-310.pyc +0 -0
- src/core/__pycache__/reranker.cpython-310.pyc +0 -0
- src/core/_init_.py +14 -0
- src/core/document_processor.py +80 -0
- src/core/embeddings.py +41 -0
- src/core/html_processor.py +195 -0
- src/core/plan_generator.py +120 -0
- src/core/ranking.py +114 -0
- src/core/reranker.py +44 -0
- src/retrieval/__pycache__/base.cpython-310.pyc +0 -0
- src/retrieval/__pycache__/graph_rag.cpython-310.pyc +0 -0
- src/retrieval/__pycache__/memo_rag.cpython-310.pyc +0 -0
- src/retrieval/base.py +15 -0
- src/retrieval/graph_rag.py +56 -0
- src/retrieval/memo_rag.py +76 -0
- src/utils/__init__.py +6 -0
- src/utils/__pycache__/__init__.cpython-310.pyc +0 -0
- src/utils/__pycache__/helpers.cpython-310.pyc +0 -0
- src/utils/helpers.py +104 -0
- src/utils/neo4j_helper.py +82 -0
.gitattributes
CHANGED
@@ -23,13 +23,11 @@
 *.pth filter=lfs diff=lfs merge=lfs -text
 *.rar filter=lfs diff=lfs merge=lfs -text
 *.safetensors filter=lfs diff=lfs merge=lfs -text
-saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.tar.* filter=lfs diff=lfs merge=lfs -text
-*.tar filter=lfs diff=lfs merge=lfs -text
 *.tflite filter=lfs diff=lfs merge=lfs -text
 *.tgz filter=lfs diff=lfs merge=lfs -text
 *.wasm filter=lfs diff=lfs merge=lfs -text
 *.xz filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
-*tfevents* filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
.gradio/certificate.pem
ADDED
@@ -0,0 +1,31 @@
+-----BEGIN CERTIFICATE-----
+MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
+TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
+cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
+WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
+ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
+MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
+h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
+0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
+A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
+T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
+B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
+B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
+KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
+OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
+jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
+qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
+rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
+HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
+hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
+ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
+3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
+NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
+ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
+TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
+jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
+oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
+4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
+mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
+emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
+-----END CERTIFICATE-----
.huggingface/YAML
ADDED
@@ -0,0 +1,9 @@
+title: Tourism Planning Assistant
+emoji: 🌍
+colorFrom: blue
+colorTo: indigo
+sdk: gradio
+sdk_version: 4.8.0
+app_file: app.py
+pinned: true
+license: mit
README.md
CHANGED
@@ -1,14 +1,40 @@
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+# Tourism Planning Assistant
+
+The Tourism Planning Assistant is an application built on RAG (Retrieval-Augmented Generation) that helps users plan travel itineraries.
+
+## Features
+
+- Generates personalized travel plans from user requirements
+- Uses a search engine to fetch up-to-date tourism information
+- Supports multiple language models (OpenAI, Deepseek)
+- Automatically retrieves and ranks relevant information
+- Displays related reference images
+
+## Deploying to Hugging Face Spaces
+
+The application is optimized for direct deployment to Hugging Face Spaces:
+
+1. Create a new Space on Hugging Face and choose Gradio as the SDK
+2. Add the following environment variables:
+   - `HF_BOCHA_API_KEY`: Bocha API key
+   - `HF_OPENAI_API_KEY`: OpenAI API key (optional)
+   - `HF_DEEPSEEK_API_KEY`: Deepseek API key (optional)
+
+## Usage
+
+1. Enter your travel requirements (e.g., "One-day trip to Hong Kong Disneyland")
+2. Select the number of travel days
+3. Click the "Generate Plan" button
+4. Receive a personalized travel plan with a detailed itinerary and reference links
+
+## Tech Stack
+
+- Gradio: user interface
+- Sentence Transformers: text embeddings
+- FlagEmbedding: text reranking
+- Hugging Face models: BGE-M3 and BGE-Reranker-Large
+- Bocha API: search engine interface
+
+## License
+
+MIT License
app.py
ADDED
@@ -0,0 +1,799 @@
+import gradio as gr
+import sys
+import os
+from pathlib import Path
+from typing import List
+# Add the project root to the Python path
+sys.path.append(str(Path(__file__).parent))
+from src.api.search_api import BochaSearch
+from src.core.document_processor import DocumentProcessor
+from src.core.ranking import RankingSystem
+from src.core.plan_generator import PlanGenerator
+from src.core.embeddings import EmbeddingModel
+from src.core.reranker import Reranker
+from src.api.llm_api import DeepseekInterface, LLMInterface, OpenAIInterface
+from src.utils.helpers import load_config
+import logging
+
+# Set up logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+class TravelRAGSystem:
+    def __init__(self):
+        self.config = load_config("config/config.yaml")
+        self.llm_instances = {}  # LLM instances keyed by provider
+
+        # Always use the "standard" retrieval method
+        self.retrieval_method = "standard"
+
+        self.init_llm_instances()
+        self.init_components()
+
+    def init_components(self):
+        # Fetch the default provider's configuration
+        default_provider = self.config['llm_settings']['default_provider']
+        provider_config = next(
+            (p for p in self.config['llm_settings']['providers']
+             if p['name'] == default_provider),
+            None
+        )
+
+        if not provider_config:
+            raise ValueError(f"No configuration found for default provider {default_provider}")
+
+        # Initialize the LLM instance
+        self.llm = self.init_llm(provider_config['name'], provider_config['model'])
+
+        self.search_engine = BochaSearch(
+            api_key=self.config['bocha_api_key'],
+            base_url=self.config['bocha_base_url']
+        )
+
+        self.doc_processor = DocumentProcessor(self.llm)
+
+        # Initialize the embedding model from its Hugging Face model ID
+        try:
+            self.embedding_model = EmbeddingModel(
+                model_name="BAAI/bge-m3"
+            )
+            logger.info("Embedding model loaded successfully")
+        except Exception as e:
+            logger.error(f"Failed to load embedding model: {str(e)}")
+            raise
+
+        # Initialize the reranker from its Hugging Face model ID
+        try:
+            self.reranker = Reranker(
+                model_path="BAAI/bge-reranker-large"
+            )
+            logger.info("Reranker model loaded successfully")
+        except Exception as e:
+            logger.error(f"Failed to load reranker model: {str(e)}")
+            raise
+
+        self.ranking_system = RankingSystem(self.embedding_model, self.reranker)
+        self.plan_generator = PlanGenerator(self.llm)
+
+    def init_llm(self, provider: str, model: str):
+        if provider == "openai":
+            return OpenAIInterface(
+                api_key=self.config['openai_api_key'],
+                model=model
+            )
+        elif provider == "deepseek":
+            return DeepseekInterface(
+                api_key=self.config['deepseek_api_key'],
+                base_url=next(
+                    p['base_url'] for p in self.config['llm_settings']['providers']
+                    if p['name'] == 'deepseek'
+                ),
+                model=model
+            )
+        else:
+            raise ValueError(f"Unsupported LLM provider: {provider}")
+
+    def init_llm_instances(self):
+        """Initialize every enabled LLM instance"""
+        for provider in self.config['llm_settings']['providers']:
+            if provider.get('enabled', False):
+                try:
+                    if provider['name'] == "openai":
+                        self.llm_instances['openai'] = OpenAIInterface(
+                            api_key=self.config['openai_api_key'],
+                            model=provider['model']
+                        )
+                    else:
+                        self.llm_instances['deepseek'] = DeepseekInterface(
+                            api_key=self.config['deepseek_api_key'],
+                            base_url=provider['base_url'],
+                            model=provider['model']
+                        )
+                    logging.info(f"Initialized {provider['name']} LLM successfully")
+                except Exception as e:
+                    logging.error(f"Failed to initialize {provider['name']} LLM: {str(e)}")
+
+    def get_llm(self, provider_name: str = None) -> LLMInterface:
+        """Return the requested LLM instance"""
+        if not provider_name:
+            provider_name = self.config['llm_settings']['default_provider']
+
+        if provider_name not in self.llm_instances:
+            raise ValueError(f"LLM provider not found or not enabled: {provider_name}")
+
+        return self.llm_instances[provider_name]
+
+    def process_query(
+        self,
+        query: str,
+        days: int,
+        llm_provider: str,
+        llm_model: str,
+        enable_images: bool = True,
+        retrieval_method: str = None
+    ) -> tuple:
+        try:
+            # Switch retrieval methods if a new one was requested
+            if retrieval_method and retrieval_method != self.retrieval_method:
+                self.set_retrieval_method(retrieval_method)
+
+            # Make sure the requested LLM provider exists
+            if llm_provider not in self.llm_instances:
+                raise ValueError(f"LLM provider {llm_provider} is not enabled or unavailable")
+            current_llm = self.llm_instances[llm_provider]
+            self.doc_processor = DocumentProcessor(current_llm)
+            self.plan_generator = PlanGenerator(current_llm)
+
+            # Make sure the query includes the number of days
+            if days > 0:
+                query = f"{query} {days} days"
+
+            # Run the search
+            logger.info(f"Running search: {query}")
+            search_results = self.search_engine.search(query)
+            logger.info(f"Search results: {search_results}")
+
+            # Process the documents
+            passages = self.doc_processor.process_documents(search_results)
+            logger.info(f"Processed documents: {passages}")
+
+            # Retrieve and rank with the current retriever
+            if hasattr(self, 'retriever'):
+                final_ranked = self.retriever.retrieve(query, passages)
+            else:
+                # Fall back to the default ranking system
+                initial_ranked = self.ranking_system.initial_ranking(query, passages)
+                final_ranked = self.ranking_system.rerank(query, initial_ranked)
+
+            # Generate the plan
+            final_plan = self.plan_generator.generate_plan(query, final_ranked)
+            logger.info(f"Generated plan: {final_plan}")
+
+            # Prepare the reference sources as a Markdown table
+            # Table header
+            table_header = "| Reference URL | Relevance Score | Retrieval Score | Rerank Score |\n| --- | --- | --- | --- |"
+
+            # Table rows
+            table_rows = []
+            for doc in final_ranked:
+                # Fall back to the URL's domain when the title is empty
+                title = doc.get('title', '').strip()
+                if not title:
+                    from urllib.parse import urlparse
+                    domain = urlparse(doc['url']).netloc
+                    title = domain
+
+                # Build the table row
+                row = (
+                    f"| [{title}]({doc['url']}) | "
+                    f"{doc.get('final_score', 0):.3f} | "
+                    f"{doc.get('retrieval_score', 0):.3f} | "
+                    f"{doc.get('rerank_score', 0):.3f} |"
+                )
+                table_rows.append(row)
+
+            # Assemble the table
+            sources = table_header + "\n" + "\n".join(table_rows)
+
+            logger.info(f"Reference sources: {sources}")
+
+            # Build the image display
+            image_html = ""
+            if enable_images:
+                try:
+                    # Request more images than needed because some will be filtered out
+                    images = self.search_engine.search_images(query, count=8)
+                    valid_images = []
+
+                    if images:
+                        # Filter the images
+                        for img in images:
+                            img_url = img.get('url', '')
+                            if img_url and self.verify_image_url(img_url):
+                                valid_images.append(img_url)
+                                if len(valid_images) >= 3:  # only three valid images are needed
+                                    break
+
+                    if valid_images:  # render only when valid images exist
+                        image_html = """
+                        <div style="display: flex; flex-direction: column; gap: 15px;">
+                        """
+                        for img_url in valid_images:
+                            image_html += f"""
+                            <div style="border-radius: 8px; overflow: hidden;
+                                        box-shadow: 0 4px 6px rgba(0, 0, 0, 0.3);">
+                                <div style="position: relative; padding-top: 66.67%;">
+                                    <img src="{img_url}"
+                                         alt="Travel-related image"
+                                         style="position: absolute; top: 0; left: 0; width: 100%;
+                                                height: 100%; object-fit: cover; transition: transform 0.3s;"
+                                         onerror="this.style.display='none'">
+                                </div>
+                            </div>
+                            """
+                        image_html += "</div>"
+                except Exception as e:
+                    logger.warning(f"Error while fetching images: {str(e)}")
+
+            # Prettify the plan content
+            plan_content = final_plan['plan']
+
+            # Strip Markdown markers
+            replacements = {
+                '###': '',  # remove level-3 heading markers
+                '##': '',   # remove level-2 heading markers
+                '# ': '',   # remove level-1 heading markers
+                '**': '',   # remove bold markers
+            }
+
+            for old, new in replacements.items():
+                plan_content = plan_content.replace(old, new)
+
+            # Format headings and paragraphs
+            paragraphs = plan_content.split('\n')
+            formatted_paragraphs = []
+
+            for p in paragraphs:
+                p = p.strip()
+                if not p:
+                    continue
+
+                if "Tour Overview" in p:
+                    # Main heading style
+                    formatted_paragraphs.append(
+                        f'<h2 style="color: #f3f4f6; margin: 10px 0 12px 0; font-size: 1.15em; '
+                        f'font-weight: 600; letter-spacing: 0.01em; line-height: 1.3; '
+                        f'border-bottom: 1px solid rgba(99, 102, 241, 0.3); padding-bottom: 8px; '
+                        f'font-family: -apple-system, BlinkMacSystemFont, Segoe UI, Roboto, Helvetica Neue, sans-serif;">'
+                        f'📍 {p}</h2>'
+                    )
+                elif ":" in p and any(x in p for x in ["Date", "Destination", "Key Attractions"]):
+                    # Key-info style, with the value portion emphasized
+                    key, value = p.split(":", 1)
+                    formatted_paragraphs.append(
+                        f'<div style="display: flex; align-items: start; margin: 4px 0; '
+                        f'padding-left: 4px;">'
+                        f'<span style="color: #818cf8; font-weight: 500; min-width: 70px; '
+                        f'font-size: 0.92em;">{key}:</span>'
+                        f'<span style="flex: 1; line-height: 1.4; color: #e2e8f0; margin-left: 8px; '
+                        f'font-family: -apple-system, BlinkMacSystemFont, Segoe UI, Roboto, Helvetica Neue, sans-serif; '
+                        f'font-size: 0.92em; font-weight: 600; letter-spacing: 0.005em;">{value.strip()}</span>'
+                        f'</div>'
+                    )
+                elif "Daily Itinerary" in p:
+                    # Main heading style
+                    formatted_paragraphs.append(
+                        f'<h2 style="color: #f3f4f6; margin: 20px 0 12px 0; font-size: 1.15em; '
+                        f'font-weight: 600; letter-spacing: 0.01em; line-height: 1.3; '
+                        f'border-bottom: 1px solid rgba(99, 102, 241, 0.3); padding-bottom: 8px; '
+                        f'font-family: -apple-system, BlinkMacSystemFont, Segoe UI, Roboto, Helvetica Neue, sans-serif;">'
+                        f'🕒️ {p}</h2>'
+                    )
+                elif " - " in p:  # time-slot heading
+                    # Subheading style
+                    formatted_paragraphs.append(
+                        f'<h3 style="color: #e2e8f0; margin: 14px 0 6px 0; font-size: 1.05em; '
+                        f'font-weight: 600; letter-spacing: 0.01em; line-height: 1.4; padding-bottom: 4px; '
+                        f'font-family: -apple-system, BlinkMacSystemFont, Segoe UI, Roboto, Helvetica Neue, sans-serif;">'
+                        f'🕒 {p}</h3>'
+                    )
+                elif p.startswith("Location") or p.startswith("Activity") or p.startswith("Transportation") or p.startswith("Specific Guidance"):
+                    # Detail-line style
+                    key, value = p.split(":", 1)
+                    icon = {
+                        "Location": "📍",
+                        "Activity": "🎯",
+                        "Transportation": "🚇",
+                        "Specific Guidance": "🗺️"
+                    }.get(key, "•")
+
+                    formatted_paragraphs.append(
+                        f'<div style="display: flex; align-items: start; margin: 4px 0; padding-left: 4px;">'
+                        f'<span style="color: #818cf8; margin-right: 8px;">{icon}</span>'
+                        f'<span style="flex: 1; line-height: 1.4; color: #e2e8f0; '
+                        f'font-family: -apple-system, BlinkMacSystemFont, Segoe UI, Roboto, Helvetica Neue, sans-serif; '
+                        f'font-size: 0.92em; font-weight: 400; letter-spacing: 0.005em;">{value.strip()}</span>'
+                        f'</div>'
+                    )
+                else:
+                    # Plain paragraph style
+                    formatted_paragraphs.append(
+                        f'<p style="margin: 8px 0; line-height: 1.5; color: #e2e8f0; '
+                        f'font-family: -apple-system, BlinkMacSystemFont, Segoe UI, Roboto, Helvetica Neue, sans-serif; '
+                        f'font-size: 0.95em; font-weight: 400; letter-spacing: 0.005em;">{p}</p>'
+                    )
+
+            plan_content = '\n'.join(formatted_paragraphs)
+
+            # Wrap everything in a dark-themed container
+            final_output = f"""
+            <div style="max-width: 100%; padding: 24px; background: rgba(17, 24, 39, 0.7);
+                        border-radius: 16px; box-shadow: 0 4px 6px rgba(0, 0, 0, 0.2);">
+                <div style="margin-top: 20px;">
+                    {plan_content}
+                </div>
+            </div>
+            """
+
+            return final_output, sources, image_html
+
+        except Exception as e:
+            logger.error(f"Error processing query: {str(e)}")
+            return f"Sorry, an error occurred while processing your request: {str(e)}", "", ""
+
+    def verify_image_url(self, url: str) -> bool:
+        """Check that an image URL is reachable and meets the display requirements"""
+        try:
+            import requests
+            from PIL import Image
+            import io
+            import numpy as np
+            from PIL import ImageDraw, ImageFont
+
+            # Fetch the image
+            response = requests.get(url, timeout=3)
+            if response.status_code != 200:
+                return False
+
+            # Check the content type
+            content_type = response.headers.get('content-type', '')
+            if 'image' not in content_type.lower():
+                return False
+
+            # Read the image
+            img = Image.open(io.BytesIO(response.content))
+
+            # 1. Check the image dimensions
+            width, height = img.size
+            if width < 300 or height < 300:  # drop images that are too small
+                return False
+
+            # 2. Check the aspect ratio
+            aspect_ratio = width / height
+            if aspect_ratio < 0.5 or aspect_ratio > 2.0:  # drop badly proportioned images
+                return False
+
+            # 3. Convert to a numpy array for analysis
+            img_array = np.array(img)
+
+            # 4. Check whether the image is too flat (likely a text-only image)
+            if len(img_array.shape) == 3:  # make sure it is a color image
+                std = np.std(img_array)
+                if std < 30:  # a low standard deviation means a monotonous image
+                    return False
+
+            # 5. Detect text regions (simple heuristic)
+            # Convert to grayscale
+            if img.mode != 'L':
+                img_gray = img.convert('L')
+            else:
+                img_gray = img
+
+            # Compute the edge density
+            from PIL import ImageFilter
+            edges = img_gray.filter(ImageFilter.FIND_EDGES)
+            edge_density = np.mean(np.array(edges))
+
+            # A very high edge density suggests lots of text
+            if edge_density > 30:
+                return False
+
+            # 6. Check whether the image is oversaturated (likely an advertisement)
+            if len(img_array.shape) == 3:
+                hsv = img.convert('HSV')
+                saturation = np.array(hsv)[:,:,1]
+                if np.mean(saturation) > 200:  # saturation too high
+                    return False
+
+            return True
+
+        except Exception as e:
+            logger.warning(f"Image validation failed: {str(e)}")
+            return False
+
+    def _format_images_html(self, images: List[str]) -> str:
+        """Format the images as an HTML gallery"""
+        if not images:
+            return ""
+
+        # Lay the images out with flexbox
+        html = """
+        <div style="display: flex; flex-wrap: wrap; gap: 10px; justify-content: center; margin-top: 20px;">
+        """
+
+        for img_url in images:
+            # Add an image container with a load-failure fallback
+            html += f"""
+            <div style="flex: 0 0 calc(50% - 10px); max-width: 300px; min-width: 200px;">
+                <img
+                    src="{img_url}"
+                    style="width: 100%; height: 200px; object-fit: cover; border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.1);"
+                    onerror="this.onerror=null; this.src='https://via.placeholder.com/300x200?text=Image+Not+Available';"
+                />
+            </div>
+            """
+
+        html += "</div>"
+
+        # Debug logging
+        logger.info(f"Generated image HTML: {html[:200]}...")  # only the first 200 characters
+
+        return html
+
+    def set_retrieval_method(self, method: str):
+        """Switch the retrieval method"""
+        if method not in ["standard"]:
+            raise ValueError(f"Unsupported retrieval method: {method}")
+
+        self.retrieval_method = method
+
+        # Initialize the retriever that matches the method
+        if method == "standard":
+            self.retriever = self.ranking_system
+
+def create_interface():
+    system = TravelRAGSystem()
+
+    # Collect the enabled providers
+    enabled_providers = [
+        provider['name']
+        for provider in system.config['llm_settings']['providers']
+        if provider['enabled']
+    ]
+
+    # Map each provider to its models
+    provider_models = {
+        provider['name']: provider['models']
+        for provider in system.config['llm_settings']['providers']
+        if provider['enabled']
+    }
+
+    # Build the interface with custom CSS
+    css = """
+    .gradio-container {
+        font-family: "PingFang SC", "Microsoft YaHei", sans-serif;
+    }
+
+    /* Target all English text */
+    [class*="message-"] {
+        font-family: 'Times New Roman', serif !important;
+    }
+
+    /* Make sure English text and digits use Times New Roman */
+    .gradio-container *:not(:lang(zh)) {
+        font-family: 'Times New Roman', serif !important;
+    }
+
+    @keyframes spin {
+        0% { transform: rotate(0deg); }
+        100% { transform: rotate(360deg); }
+    }
+
+    /* Hide the number input's spinner arrows */
+    input[type="number"]::-webkit-inner-spin-button,
+    input[type="number"]::-webkit-outer-spin-button {
+        -webkit-appearance: none;
+        margin: 0;
+    }
+
+    input[type="number"] {
+        -moz-appearance: textfield;
+    }
+
+    /* Hide the default processing text and arrows */
+    .progress-text, .meta-text-center, .progress-container {
+        display: none !important;
+    }
+
+    /* Restyle the loading animation */
+    .loading {
+        display: flex;
+        align-items: center;
+        justify-content: center;
+        gap: 8px;
+        font-size: 1.2em;
+        color: rgb(192, 192, 255);
+    }
+
+    .loading::before {
+        content: '🌍';
+        display: inline-block;
+        animation: spin 2s linear infinite;
+        filter: brightness(1.5); /* brighten the globe icon */
+    }
+
+    /* Adjust the position of Gradio's default loading indicator */
+    .progress-text {
+        display: block !important;
+        order: 3;
+        margin-top: 8px;
+        opacity: 0.7;
+    }
+
+    .meta-text-center {
+        display: block !important;
+    }
+
+    /* Make sure the loading container uses flex layout */
+    .loading-container {
+        display: flex;
+        flex-direction: column;
+        align-items: center;
+    }
+
+    /* Hide the slider's up/down arrows */
+    .num-input-plus, .num-input-minus {
+        display: none !important;
+    }
+
+    /* Hide all scroll arrows */
+    .scroll-hide,
+    .output-markdown,
+    .output-text,
+    .markdown-text,
+    .prose,
+    .gr-box,
+    .gr-panel {
+        -ms-overflow-style: none !important;
+        scrollbar-width: none !important;
+        overflow-y: hidden !important;
+        overflow: hidden !important;
+    }
+
+    .scroll-hide::-webkit-scrollbar,
+    .output-markdown::-webkit-scrollbar,
+    .output-text::-webkit-scrollbar,
+    .markdown-text::-webkit-scrollbar,
+    .prose::-webkit-scrollbar,
+    .gr-box::-webkit-scrollbar,
+    .gr-panel::-webkit-scrollbar {
+        display: none !important;
+        width: 0 !important;
+        height: 0 !important;
+    }
+
+    /* Adjust the loading-animation container */
+    .loading-container {
+        overflow: hidden !important;
+        min-height: 60px;
+    }
+
+    /* Hide Gradio's default scroll widgets */
+    .wrap.svelte-byatnx,
+    .contain.svelte-byatnx,
+    [class*='svelte'],
+    .gradio-container {
+        overflow: hidden !important;
+        overflow-y: hidden !important;
+    }
+
+    /* Disable every possible scrollbar */
+    ::-webkit-scrollbar {
+        display: none !important;
+        width: 0 !important;
+        height: 0 !important;
+    }
+
+    /* Remove the Group component's default background */
+    .custom-group {
+        border: none !important;
+        background: none !important;
+        box-shadow: none !important;
+    }
+
+    .custom-group > div {
+        border: none !important;
+        background: none !important;
+        box-shadow: none !important;
+    }
+
+    /* Image container styles */
+    .images-container {
+        margin-top: 20px;
+        padding: 10px;
+        background: #fff;
+        border-radius: 8px;
+        box-shadow: 0 2px 4px rgba(0,0,0,0.1);
+    }
+
+    .images-container img {
+        transition: transform 0.3s ease;
+    }
+
+    .images-container img:hover {
+        transform: scale(1.05);
+    }
+
+    /* Make sure the image container is visible */
+    #component-13 {
+        min-height: 200px;
+        overflow: visible !important;
+    }
+    """
+
+    # JavaScript for the loading-status text
+    js = """
+    function showLoading() {
+        document.getElementById('loading_status').innerHTML = '<p class="loading">Generating your personalized travel plan...</p>';
+        return ['', ''];
+    }
+    """
+
+    with gr.Blocks(theme=gr.themes.Soft(), css=css) as interface:
+        gr.Markdown("""
+        # 🌟 Tourism Planning Assistant 🌟
+
+        Welcome to the Smart Travel Planning Assistant! Simply input your travel requirements, and we'll generate a personalized travel plan for you.
+
+        ### Instructions
+        1. Describe your travel needs in the input box (e.g., 'One-day trip to Hong Kong Disneyland')
+        2. Select the number of days for your plan
+        3. Click the "Generate Plan" button
+        """)
+
+        with gr.Row():
+            with gr.Column(scale=4):
+                llm_provider = gr.Dropdown(
+                    choices=enabled_providers,
+                    value=system.config['llm_settings']['default_provider'],
+                    label="Select LLM Provider"
+                )
+                llm_model = gr.Dropdown(
+                    choices=provider_models[system.config['llm_settings']['default_provider']],
+                    label="Select Model"
+                )
+
+                # Refresh the model choices when the provider changes
+                def update_model_choices(provider):
+                    return gr.Dropdown(choices=provider_models[provider])
+
+                # Register the provider-change callback
+                llm_provider.change(
+                    fn=update_model_choices,
+                    inputs=[llm_provider],
+                    outputs=[llm_model]
+                )
+
+                query_input = gr.Textbox(
+                    label="Travel Requirements",
+                    placeholder="Please enter your travel requirements, e.g.: One-day trip to Hong Kong Disneyland",
+                    lines=2
+                )
+                days_input = gr.Slider(
+                    minimum=1,
+                    maximum=7,
+                    value=1,
+                    step=1,
+                    label="Number of Days"
+                )
+
+                # Checkbox that toggles image search
+                show_images = gr.Checkbox(
+                    label="Search Related Images",
+                    value=True,
+                    info="Whether to search and display related reference images"
+                )
+
+                # Only the standard method remains (memorag and graphrag were removed)
+                retrieval_method = gr.Radio(
+                    choices=["standard"],
+                    value="standard",
+                    label="Retrieval Method",
+                    info="Choose different retrieval strategies",
+                    visible=False  # hidden because there is only one option
+                )
+
+                submit_btn = gr.Button("Generate Plan", variant="primary")
+                loading_status = gr.Markdown("", elem_id="loading_status", show_label=False)
+
+                # Image display area in the left column
+                images_container = gr.HTML(
+                    value="",  # start with an empty string
+                    visible=True,
+                    label="Related Images"
+                )
+
+                # Update the image area when the checkbox changes
+                show_images.change(
+                    fn=lambda x: "" if not x else "<div></div>",  # empty string when images are disabled
+                    inputs=[show_images],
+                    outputs=[images_container]
+                )
+
+            with gr.Column(scale=6):
+                with gr.Tabs():
+                    with gr.TabItem("Travel Plan"):
+                        plan_output = gr.HTML(label="Generated Travel Plan", show_label=False)
+                    with gr.TabItem("References and Evaluation"):
+                        sources_output = gr.Markdown(label="References and Evaluation", show_label=False)
+
+        # Example queries
+        gr.Examples(
+            examples=[
+                ["One-day trip to Hong Kong Disneyland", 1],
+                ["Family trip to Hong Kong Ocean Park", 1],
+                ["Hong Kong Shopping and Food Tour", 2],
+                ["Hong Kong Cultural Experience Tour", 3]
+            ],
+            inputs=[query_input, days_input],
+            label="Example Queries"
+        )
+
+        def show_loading():
+            loading_html = "<div class='loading-container'><p class='loading'>Generating your personalized travel plan...</p></div>"
+            return loading_html, loading_html, "", ""
+
+        def process_with_images(query, days, llm_provider, llm_model, enable_images, retrieval_method):
+            plan_html, sources_md, images_html = system.process_query(
+                query, days, llm_provider, llm_model,
+                enable_images, retrieval_method
+            )
+
+            # Debug logging
+            logger.info(f"Image HTML length: {len(images_html) if images_html else 0}")
+
+            return plan_html, sources_md, images_html
+
+        # Wire up the submit button
+        submit_btn.click(
+            fn=show_loading,
+            inputs=None,
+            outputs=[loading_status, plan_output, sources_output, images_container]
+        ).then(
+            fn=process_with_images,
+            inputs=[
+                query_input,
+                days_input,
+                llm_provider,
+                llm_model,
+                show_images,
+                retrieval_method
+            ],
+            outputs=[plan_output, sources_output, images_container]  # keep this order
+        ).then(
+            fn=lambda: "",
+            inputs=None,
+            outputs=[loading_status]
+        )
+
+        # Footer
+        gr.Markdown("""
+        ### 📝 Notes
+        - Plan generation may take some time, please be patient
+        - Queries should include specific locations and activity preferences
+        - All plans are AI-generated, please adjust according to actual circumstances
+
+        Powered by RAG for Tourism system © 2024
+        """)
+
+    return interface
+
+if __name__ == "__main__":
+    demo = create_interface()
+    # Rely on the Hugging Face Spaces environment
+    demo.launch(
+        server_name="0.0.0.0",
+        server_port=7860,
+        share=False,  # Hugging Face Spaces already provides public access
+        debug=False
+    )
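
The class above can also be exercised without the Gradio UI, which is handy for smoke-testing the retrieval pipeline. A minimal sketch, assuming config/config.yaml resolves to valid API keys; the argument names follow process_query's signature above:

# Run from the repository root so config/config.yaml resolves.
from app import TravelRAGSystem

system = TravelRAGSystem()
plan_html, sources_md, image_html = system.process_query(
    query="One-day trip to Hong Kong Disneyland",
    days=1,
    llm_provider="deepseek",
    llm_model="deepseek-chat",
    enable_images=False,  # skip image search for a faster test
)
print(sources_md)  # Markdown table of ranked reference URLs and scores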
config/config.yaml
ADDED
@@ -0,0 +1,72 @@
+# API Keys - set via environment variables on Hugging Face Spaces
+google_api_key: ${HF_GOOGLE_API_KEY}
+google_cx: ${HF_GOOGLE_CX}
+bing_api_key: ${HF_BING_API_KEY}
+openai_api_key: ${HF_OPENAI_API_KEY}
+deepseek_api_key: ${HF_DEEPSEEK_API_KEY}
+bocha_api_key: ${HF_BOCHA_API_KEY}
+bocha_base_url: "https://api.bochaai.com"
+
+# LLM Settings
+llm_settings:
+  providers:
+    - name: "deepseek"
+      enabled: true
+      model: "deepseek-chat"
+      base_url: "https://api.deepseek.com"
+      api_key: ${deepseek_api_key}
+      models: ["deepseek-chat"]
+    - name: "openai"
+      enabled: true
+      model: "gpt-4o"
+      base_url: "https://api.openai.com/v1"  # standard OpenAI API URL
+      api_key: ${openai_api_key}
+      models: ["gpt-4o"]
+  default_provider: "deepseek"  # provider used by default
+
+# Retrieval settings
+retrieval_settings:
+  default_method: "standard"
+  methods:
+    - name: "standard"
+      enabled: true
+      model_settings:
+        embedding_model: "BAAI/bge-m3"  # Hugging Face model ID
+        reranker_model: "BAAI/bge-reranker-large"  # Hugging Face model ID
+
+# Search Settings
+max_results: 20
+language: "zh-CN"
+search_provider: "bocha"
+
+# Document Processing
+max_passage_length: 500
+min_passage_length: 100
+
+# Vector Search Settings
+embedding_model: "BAAI/bge-m3"  # Hugging Face model ID
+reranker_model: "BAAI/bge-reranker-large"  # Hugging Face model ID
+batch_size: 32
+use_gpu: true
+
+# Ranking Settings
+initial_top_k: 100
+final_top_k: 3
+retrieval_weight: 0.3
+rerank_weight: 0.7
+
+# Search Settings
+search_settings:
+  trusted_domains:
+    - 'discoverhongkong.com'
+    - 'tourism.gov.hk'
+    - 'hong-kong-travel.com'
+    - 'timeout.com.hk'
+    - 'openrice.com'
+    - 'lcsd.gov.hk'
+    - 'hkpl.gov.hk'
+
+  proxy:
+    enabled: false
+    host: "127.0.0.1"
+    port: 8880
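
The ${HF_...} placeholders imply an environment-variable substitution step when this file is loaded. A hypothetical sketch of what load_config (in src/utils/helpers.py, whose contents are not shown in this section) presumably does; the regex pass and its details are assumptions:

import os
import re
import yaml

_VAR = re.compile(r"\$\{(\w+)\}")

def load_config(path: str) -> dict:
    """Parse a YAML config, expanding ${VAR} placeholders from the environment."""
    with open(path, "r", encoding="utf-8") as f:
        raw = f.read()
    # Replace each ${NAME} with os.environ["NAME"] (empty string if unset).
    expanded = _VAR.sub(lambda m: os.environ.get(m.group(1), ""), raw)
    return yaml.safe_load(expanded)

Note that `api_key: ${deepseek_api_key}` under the providers refers to another config key rather than an environment variable, so entries like that would need a second resolution pass over the parsed dict.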
requirements.txt
ADDED
@@ -0,0 +1,16 @@
+requests>=2.31.0
+beautifulsoup4>=4.12.0
+trafilatura>=1.6.1
+torch>=2.0.0
+transformers>=4.36.0
+openai>=1.3.0
+pyyaml>=6.0.1
+faiss-cpu>=1.7.4
+sentence-transformers>=2.2.0
+gradio>=4.8.0
+neo4j>=5.14.0
+langchain>=0.0.350
+matplotlib>=3.8.0
+Pillow>=9.0.0
+numpy>=1.22.0
+FlagEmbedding>=1.1.5
src/__init__.py
ADDED
@@ -0,0 +1,4 @@
+from .core.html_processor import HTMLProcessor
+from .core.document_processor import DocumentProcessor
+
+__all__ = ['HTMLProcessor', 'DocumentProcessor']
src/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (291 Bytes).
src/api/__init__.py
ADDED
@@ -0,0 +1,47 @@
+from typing import Dict, Any
+from fastapi import FastAPI
+from fastapi.middleware.cors import CORSMiddleware
+from .llm_api import DeepseekInterface
+from .search_api import GoogleSearch
+
+def create_app(config: Dict[str, Any] = None) -> FastAPI:
+    """
+    Create and configure the FastAPI application
+    """
+    app = FastAPI(
+        title="Travel RAG API",
+        description="Travel recommendation system using RAG",
+        version="1.0.0"
+    )
+
+    # Configure CORS
+    app.add_middleware(
+        CORSMiddleware,
+        allow_origins=["*"],
+        allow_credentials=True,
+        allow_methods=["*"],
+        allow_headers=["*"],
+    )
+
+    # Store the configuration
+    if config:
+        app.state.config = config
+
+        # Initialize the LLM
+        app.state.llm = DeepseekInterface(
+            api_key=config['deepseek_api_key'],
+            base_url=config['llm_settings']['deepseek']['base_url'],
+            model=config['llm_settings']['deepseek']['models'][0]
+        )
+
+    # The search engine is initialized inside init_app
+    from .routes import init_app
+    init_app(app)
+
+    # Import and register the routes
+    from .routes import router
+    app.include_router(router)
+
+    return app
+
+__all__ = ['create_app']
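
A sketch of serving this factory locally with uvicorn (uvicorn is not pinned in requirements.txt, so it is assumed to be installed separately). Note that create_app reads config['llm_settings']['deepseek'], a different shape from the providers list in config/config.yaml above, so a matching dict is built inline here:

import uvicorn
from src.api import create_app

config = {
    "deepseek_api_key": "your-deepseek-key",  # placeholder
    "llm_settings": {
        "deepseek": {
            "base_url": "https://api.deepseek.com",
            "models": ["deepseek-chat"],
        }
    },
}
app = create_app(config)
uvicorn.run(app, host="0.0.0.0", port=8000)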
src/api/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (1.17 kB).

src/api/__pycache__/llm_api.cpython-310.pyc
ADDED
Binary file (8.51 kB).

src/api/__pycache__/ollama_api.cpython-310.pyc
ADDED
Binary file (3.59 kB).

src/api/__pycache__/routes.cpython-310.pyc
ADDED
Binary file (2.91 kB).

src/api/__pycache__/search_api.cpython-310.pyc
ADDED
Binary file (6.3 kB).
src/api/llm_api.py
ADDED
@@ -0,0 +1,319 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List, Dict
|
2 |
+
from abc import ABC, abstractmethod
|
3 |
+
from openai import OpenAI
|
4 |
+
import logging
|
5 |
+
import httpx
|
6 |
+
|
7 |
+
class LLMInterface(ABC):
|
8 |
+
@abstractmethod
|
9 |
+
def generate(self, prompt: str) -> str:
|
10 |
+
pass
|
11 |
+
|
12 |
+
class DeepseekInterface(LLMInterface):
|
13 |
+
def __init__(self, api_key: str, base_url: str, model: str):
|
14 |
+
self.api_key = api_key
|
15 |
+
self.base_url = base_url
|
16 |
+
self.model = model
|
17 |
+
self.headers = {
|
18 |
+
"Authorization": f"Bearer {api_key}",
|
19 |
+
"Content-Type": "application/json"
|
20 |
+
}
|
21 |
+
|
22 |
+
def _build_system_prompt(self, role: str) -> str:
|
23 |
+
"""构建系统提示词"""
|
24 |
+
roles = {
|
25 |
+
"summarizer": "You are a professional tourism content analyst, good at extracting and summarizing key tourism-related information. Please answer in English",
|
26 |
+
"planner": "You are a professional travel planner who is good at making detailed travel plans. Please answer in English"
|
27 |
+
}
|
28 |
+
return roles.get(role, "You are a professional AI assistant. Please answer in English")
|
29 |
+
|
30 |
+
def generate(self, prompt: str, role: str = "planner") -> str:
|
31 |
+
import requests
|
32 |
+
from requests.adapters import HTTPAdapter
|
33 |
+
from urllib3.util.retry import Retry
|
34 |
+
|
35 |
+
session = requests.Session()
|
36 |
+
retries = Retry(
|
37 |
+
total=3,
|
38 |
+
backoff_factor=1,
|
39 |
+
status_forcelist=[500, 502, 503, 504]
|
40 |
+
)
|
41 |
+
session.mount('https://', HTTPAdapter(max_retries=retries))
|
42 |
+
|
43 |
+
payload = {
|
44 |
+
"model": self.model,
|
45 |
+
"messages": [
|
46 |
+
{
|
47 |
+
"role": "system",
|
48 |
+
"content": self._build_system_prompt(role)
|
49 |
+
},
|
50 |
+
{
|
51 |
+
"role": "user",
|
52 |
+
"content": prompt
|
53 |
+
}
|
54 |
+
],
|
55 |
+
"temperature": 0.7,
|
56 |
+
"max_tokens": 2000
|
57 |
+
}
|
58 |
+
|
59 |
+
try:
|
60 |
+
response = session.post(
|
61 |
+
f"{self.base_url}/v1/chat/completions",
|
62 |
+
headers=self.headers,
|
63 |
+
json=payload,
|
64 |
+
timeout=(10, 60)
|
65 |
+
)
|
66 |
+
response.raise_for_status()
|
67 |
+
return response.json()['choices'][0]['message']['content']
|
68 |
+
except requests.exceptions.Timeout:
|
69 |
+
print("Deepseek API request timeout, retrying...")
|
70 |
+
return "Sorry, due to network issues, content generation is temporarily unavailable. Please try again later."
|
71 |
+
except requests.exceptions.RequestException as e:
|
72 |
+
print(f"Error calling Deepseek API: {str(e)}")
|
73 |
+
return "Sorry, an error occurred while generating content. Please try again later."
|
74 |
+
|
75 |
+
def summarize_document(self, content: str, title: str, url: str) -> str:
|
76 |
+
"""使用 Deepseek 总结文档"""
|
77 |
+
prompt = f"""Please analyze the following tourism web content and generate a rich summary paragraph.
|
78 |
+
|
79 |
+
Web Title: {title}
|
80 |
+
Web Link: {url}
|
81 |
+
|
82 |
+
Web Content:
|
83 |
+
{content[:4000]}
|
84 |
+
|
85 |
+
Requirements:
|
86 |
+
1. The summary should be between 300-500 words
|
87 |
+
2. Keep the most important tourism information (attractions, suggestions, tips, etc.)
|
88 |
+
3. Use an objective tone
|
89 |
+
4. Information should be accurate and practical
|
90 |
+
5. Remove marketing and advertising content
|
91 |
+
6. Maintain logical coherence
|
92 |
+
|
93 |
+
Please return the summary content directly, without any other explanation."""
|
94 |
+
|
95 |
+
return self.generate(prompt, role="summarizer")
|
96 |
+
|
97 |
+
def generate_travel_plan(self, query: str, context: List[Dict]) -> str:
|
98 |
+
# 构建更结构化的上下文
|
99 |
+
context_text = "\n\n".join([
|
100 |
+
f"Source {i+1} ({doc.get('title', 'Unknown Title')}):\n{doc['passage']}"
|
101 |
+
for i, doc in enumerate(context)
|
102 |
+
])
|
103 |
+
|
104 |
+
prompt = f"""As a professional travel planner, please create a detailed travel plan based on the user's needs and reference materials.
|
105 |
+
|
106 |
+
User Needs: {query}
|
107 |
+
|
108 |
+
Reference Materials:
|
109 |
+
{context_text}
|
110 |
+
|
111 |
+
Please provide the following content:
|
112 |
+
1. Itinerary Overview (Overall arrangement and key attractions)
|
113 |
+
2. Daily detailed itinerary (includes specific time, location, and transportation methods)
|
114 |
+
3. Traffic suggestions (includes practical APP recommendations)
|
115 |
+
4. Accommodation recommendations (includes specific areas and hotel suggestions)
|
116 |
+
5. Food recommendations (includes specialty restaurants and snacks)
|
117 |
+
6. Practical tips (weather, clothing, essential items, etc.)
|
118 |
+
|
119 |
+
Requirements:
|
120 |
+
1. The itinerary should be reasonable, considering the distance between attractions
|
121 |
+
2. Provide specific time points
|
122 |
+
3. Include detailed traffic guidance
|
123 |
+
4. Suggestions should be specific and practical
|
124 |
+
5. Consider actual conditions (e.g., opening hours of attractions)
|
125 |
+
|
126 |
+
Please return the travel plan content directly, without any other explanation."""
|
127 |
+
|
128 |
+
return self.generate(prompt, role="planner")
|
129 |
+
|
130 |
+
class OllamaInterface(LLMInterface):
|
131 |
+
def __init__(self, base_url: str, model: str):
|
132 |
+
self.base_url = base_url.rstrip('/')
|
133 |
+
self.model = model
|
134 |
+
self.headers = {
|
135 |
+
"Content-Type": "application/json"
|
136 |
+
}
|
137 |
+
|
138 |
+
def _build_system_prompt(self, role: str) -> str:
|
139 |
+
"""构建系统提示词"""
|
140 |
+
roles = {
|
141 |
+
"summarizer": "You are a professional tourism content analyst, good at extracting and summarizing key tourism-related information. Please answer in English",
|
142 |
+
"planner": "You are a professional travel planner who is good at making detailed travel plans. Please answer in English"
|
143 |
+
}
|
144 |
+
return roles.get(role, "You are a professional AI assistant. Please answer in English")
|
145 |
+
|
146 |
+
def generate(self, prompt: str, role: str = "planner") -> str:
|
147 |
+
import requests
|
148 |
+
|
149 |
+
payload = {
|
150 |
+
"model": self.model,
|
151 |
+
"messages": [
|
152 |
+
{
|
153 |
+
"role": "system",
|
154 |
+
"content": self._build_system_prompt(role)
|
155 |
+
},
|
156 |
+
{
|
157 |
+
"role": "user",
|
158 |
+
"content": prompt
|
159 |
+
}
|
160 |
+
],
|
161 |
+
"stream": False
|
162 |
+
}
|
163 |
+
|
164 |
+
try:
|
165 |
+
response = requests.post(
|
166 |
+
f"{self.base_url}/api/chat",
|
167 |
+
headers=self.headers,
|
168 |
+
json=payload,
|
169 |
+
timeout=(10, 60)
|
170 |
+
)
|
171 |
+
response.raise_for_status()
|
172 |
+
return response.json()['message']['content']
|
173 |
+
except Exception as e:
|
174 |
+
print(f"Error calling Ollama API: {str(e)}")
|
175 |
+
return "Sorry, an error occurred while generating content. Please try again later."
|
176 |
+
|
177 |
+
def summarize_document(self, content: str, title: str, url: str) -> str:
|
178 |
+
"""使用 Ollama 总结文档"""
|
179 |
+
prompt = f"""Please analyze the following tourism web content and generate a rich summary paragraph.
|
180 |
+
|
181 |
+
Web Title: {title}
|
182 |
+
Web Link: {url}
|
183 |
+
|
184 |
+
Web Content:
|
185 |
+
{content[:4000]}
|
186 |
+
|
187 |
+
Requirements:
|
188 |
+
1. The summary should be between 300-500 words
|
189 |
+
2. Keep the most important tourism information (attractions, suggestions, tips, etc.)
|
190 |
+
3. Use an objective tone
|
191 |
+
4. Information should be accurate and practical
|
192 |
+
5. Remove marketing and advertising content
|
193 |
+
6. Maintain logical coherence
|
194 |
+
|
195 |
+
Please return the summary content directly, without any other explanation."""
|
196 |
+
|
197 |
+
return self.generate(prompt, role="summarizer")
|
198 |
+
|
199 |
+
def generate_travel_plan(self, query: str, context: List[Dict]) -> str:
|
200 |
+
# 构建更结构化的上下文
|
201 |
+
context_text = "\n\n".join([
|
202 |
+
f"来源 {i+1} ({doc.get('title', '未知标题')}):\n{doc['passage']}"
|
203 |
+
for i, doc in enumerate(context)
|
204 |
+
])
|
205 |
+
|
206 |
+
prompt = f"""As a professional travel planner, please create a detailed travel plan based on the user's needs and reference materials.
|
207 |
+
|
208 |
+
User Needs: {query}
|
209 |
+
|
210 |
+
Reference Materials:
|
211 |
+
{context_text}
|
212 |
+
|
213 |
+
Please provide the following content:
|
214 |
+
1. Itinerary Overview (Overall arrangement and key attractions)
|
215 |
+
2. Daily detailed itinerary (includes specific time, location, and transportation methods)
|
216 |
+
3. Traffic suggestions (includes practical APP recommendations)
|
217 |
+
4. Accommodation recommendations (includes specific areas and hotel suggestions)
|
218 |
+
5. Food recommendations (includes specialty restaurants and snacks)
|
219 |
+
6. Practical tips (weather, clothing, essential items, etc.)
|
220 |
+
|
221 |
+
Requirements:
|
222 |
+
1. The itinerary should be reasonable, considering the distance between attractions
|
223 |
+
2. Provide specific time points
|
224 |
+
3. Include detailed traffic guidance
|
225 |
+
4. Suggestions should be specific and practical
|
226 |
+
5. Consider actual conditions (e.g., opening hours of attractions)
|
227 |
+
|
228 |
+
Please return the travel plan content directly, without any other explanation."""

        return self.generate(prompt, role="planner")


class OpenAIInterface(LLMInterface):
    def __init__(self, api_key: str, model: str = "gpt-4o", base_url: str = "https://api.feidaapi.com/v1"):
        self.api_key = api_key
        self.model = model
        self.client = OpenAI(api_key=api_key, base_url=base_url)

    def _build_system_prompt(self, role: str) -> str:
        """Build the system prompt."""
        roles = {
            "summarizer": "You are a professional tourism content analyst, good at extracting and summarizing key tourism-related information. Please answer in English",
            "planner": "You are a professional travel planner who is good at making detailed travel plans. Please answer in English"
        }
        return roles.get(role, "You are a professional AI assistant. Please answer in English")

    def generate(self, prompt: str, role: str = "planner") -> str:
        try:
            messages = [
                {"role": "system", "content": self._build_system_prompt(role)},
                {"role": "user", "content": prompt}
            ]

            response = self.client.chat.completions.create(
                model=self.model,
                messages=messages,
                temperature=0.7,
                max_tokens=2000
            )

            return response.choices[0].message.content

        except Exception as e:
            logging.error(f"Error calling OpenAI API: {str(e)}")
            return "Sorry, an error occurred while generating content. Please try again later."

    def summarize_document(self, content: str, title: str, url: str) -> str:
        """Summarize a document with OpenAI."""
        prompt = f"""Please analyze the following tourism web content and generate a rich summary paragraph.

Web Title: {title}
Web Link: {url}

Web Content:
{content[:4000]}

Requirements:
1. The summary should be between 300 and 500 words
2. Keep the most important tourism information (attractions, suggestions, tips, etc.)
3. Use an objective tone
4. Information should be accurate and practical
5. Remove marketing and advertising content
6. Maintain logical coherence

Please return the summary content directly, without any other explanation."""

        return self.generate(prompt, role="summarizer")

    def generate_travel_plan(self, query: str, context: List[Dict]) -> str:
        # Build a more structured context
        context_text = "\n\n".join([
            f"Source {i+1} ({doc.get('title', 'Unknown Title')}):\n{doc['passage']}"
            for i, doc in enumerate(context)
        ])

        prompt = f"""As a professional travel planner, please create a detailed travel plan based on the user's needs and reference materials.

User Needs: {query}

Reference Materials:
{context_text}

Please provide the following content:
1. Itinerary Overview (Overall arrangement and key attractions)
2. Daily detailed itinerary (includes specific time, location, and transportation methods)
3. Traffic suggestions (includes practical APP recommendations)
4. Accommodation recommendations (includes specific areas and hotel suggestions)
5. Food recommendations (includes specialty restaurants and snacks)
6. Practical tips (weather, clothing, essential items, etc.)

Requirements:
1. The itinerary should be reasonable, considering the distance between attractions
2. Provide specific time points
3. Include detailed traffic guidance
4. Suggestions should be specific and practical
5. Consider actual conditions (e.g., opening hours of attractions)

Please return the travel plan content directly, without any other explanation."""

        return self.generate(prompt, role="planner")
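A minimal usage sketch for OpenAIInterface (the API key and sample context are placeholders, not part of the commit; requires network access to the configured endpoint):

from src.api.llm_api import OpenAIInterface

# Hypothetical key and data, for illustration only
llm = OpenAIInterface(api_key="sk-...")
context = [{"title": "HK guide",
            "passage": "Victoria Peak is best visited at sunset; take the Peak Tram from Central."}]

plan = llm.generate_travel_plan("a relaxed 3-day Hong Kong trip", context)
print(plan)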
src/api/ollama_api.py
ADDED
@@ -0,0 +1,101 @@
from typing import List, Dict
from .llm_api import LLMInterface
import requests
import json

class OllamaInterface(LLMInterface):
    def __init__(self, model_name: str = "qwen2.5:32b_ctx32k", base_url: str = "http://localhost:11434"):
        self.model_name = model_name
        self.base_url = base_url

    def generate(self, prompt: str) -> str:
        """Generate a response with Ollama."""
        try:
            response = requests.post(
                f"{self.base_url}/api/generate",
                json={
                    "model": self.model_name,
                    "prompt": prompt,
                    "stream": False,
                    "options": {
                        "temperature": 0.7,
                        "top_p": 0.9,
                        "top_k": 40,
                    }
                }
            )
            response.raise_for_status()
            return response.json()['response']
        except Exception as e:
            print(f"Ollama generation error: {e}")
            return ""

    def _build_system_prompt(self, role: str) -> str:
        """Build the system prompt."""
        roles = {
            "summarizer": "You are a professional tourism content analyst, good at extracting and summarizing key tourism-related information.",
            "planner": "You are a professional travel planner who is good at making detailed travel plans."
        }
        return roles.get(role, "You are a professional AI assistant.")

    def summarize_document(self, content: str, title: str, url: str) -> str:
        """Summarize a document with Qwen 2.5."""
        system_prompt = self._build_system_prompt("summarizer")
        prompt = f"""{system_prompt}

Please analyze the following tourism web content and generate an information-rich summary paragraph.

Web Title: {title}
Web Link: {url}

Web Content:
{content[:4000]}

Requirements:
1. Keep the summary between 300 and 500 characters
2. Keep the most important tourism information (attractions, suggestions, tips, etc.)
3. Use an objective tone
4. Information should be accurate and practical
5. Remove marketing and advertising content
6. Maintain logical coherence

Please return the summary content directly, without any other explanation."""

        return self.generate(prompt)

    def generate_travel_plan(self, query: str, context: List[Dict]) -> str:
        """Generate a travel plan with Qwen 2.5."""
        system_prompt = self._build_system_prompt("planner")

        context_text = "\n\n".join([
            f"Source {i+1}:\n{doc['passage']}"
            for i, doc in enumerate(context)
        ])

        prompt = f"""{system_prompt}

Based on the following information, please create a detailed travel plan for the user.

User Needs: {query}

Reference Materials:
{context_text}

Please provide the following content:
1. Itinerary overview
2. Daily detailed itinerary
3. Transportation suggestions
4. Accommodation recommendations
5. Food recommendations
6. Notes and practical tips

Requirements:
1. The plan should be detailed and practical
2. The schedule should be reasonable
3. Suggestions should be specific
4. Consider actual conditions (e.g., travel time, attraction opening hours)
5. Reasonable details may be filled in from the context

Please return the travel plan content directly, without any other explanation."""

        return self.generate(prompt)
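A sketch of calling this interface against a local Ollama server (assumes the server is running on :11434 with the model already pulled):

from src.api.ollama_api import OllamaInterface

llm = OllamaInterface(model_name="qwen2.5:32b_ctx32k")
text = llm.generate("List three Hong Kong attractions in one sentence each.")
print(text or "empty response - is Ollama running and the model pulled?")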
src/api/routes.py
ADDED
@@ -0,0 +1,141 @@
from fastapi import APIRouter, HTTPException, Request
from typing import Dict, Any
from pydantic import BaseModel
from src.core.document_processor import DocumentProcessor
from src.core.ranking import RankingSystem
from src.core.plan_generator import PlanGenerator
from src.core.embeddings import EmbeddingModel
from src.core.reranker import Reranker
from src.api.search_api import GoogleSearch, BochaSearch
from src.utils.helpers import setup_proxy
from src.api.llm_api import DeepseekInterface
import logging

logger = logging.getLogger(__name__)

# Create the router
router = APIRouter(
    tags=["travel"]  # tag shown in the Swagger UI
)

# Root route
@router.get("/")
async def root():
    return {
        "message": "Welcome to the travel recommendation system API",
        "status": "running",
        "version": "1.0.0",
        "endpoints": {
            "health check": "/health",
            "travel recommendation": "/api/v1/recommend",
            "API docs": "/docs"
        }
    }

# Request model
class TravelQuery(BaseModel):
    query: str
    location: str = None
    max_results: int = 10

# Response model
class TravelResponse(BaseModel):
    recommendations: list
    query: str
    metadata: Dict[str, Any]

def init_app(app):
    """Initialize the application."""
    # Set up the proxy and get its availability status
    proxies, proxy_available = setup_proxy(app.state.config)
    app.state.proxies = proxies

    # Choose the search engine based on proxy availability
    if proxy_available:
        app.state.search = GoogleSearch(
            api_key=app.state.config['google_api_key'],
            cx=app.state.config['google_cx'],
            proxies=proxies
        )
        logging.info("Using the Google search engine")
    else:
        app.state.search = BochaSearch(
            api_key=app.state.config['bocha_api_key'],
            base_url=app.state.config['bocha_base_url']
        )
        logging.info("Using the Bocha search engine")

    # Initialize the Deepseek LLM
    app.state.llm = DeepseekInterface(
        api_key=app.state.config['deepseek_api_key'],
        base_url=app.state.config['llm_settings']['deepseek']['base_url'],
        model=app.state.config['llm_settings']['deepseek']['models'][0]
    )

    # ... other initialization code ...

@router.post("/api/v1/recommend", response_model=TravelResponse)
async def get_travel_recommendations(query: TravelQuery, request: Request):
    """
    Get travel recommendations.
    """
    logger.info(f"Received query request: {query.dict()}")
    try:
        # Use the search instance that already has the proxy configured
        search = request.app.state.search
        llm = request.app.state.llm

        # Run the search
        logger.info("Starting search...")
        search_results = search.search(query.query)
        logger.info(f"Search finished, got {len(search_results)} results")

        # Process documents
        doc_processor = DocumentProcessor(llm)
        passages = doc_processor.process_documents(search_results)
        # process_documents already returns dicts with a 'passage' key,
        # so wrap only if plain strings slipped through
        passages = [p if isinstance(p, dict) else {'passage': p} for p in passages]
        logging.info(f"Passages structure: {passages[:1]}")  # log the structure of the first element

        # Initialize the ranking system
        embedding_model = EmbeddingModel("BAAI/bge-m3")
        reranker = Reranker("BAAI/bge-reranker-large")
        ranking_system = RankingSystem(embedding_model, reranker)

        # Two-stage ranking
        initial_ranked = ranking_system.initial_ranking(
            query.query,
            passages,
            10  # initial_top_k
        )

        final_ranked = ranking_system.rerank(
            query.query,
            initial_ranked,
            3  # final_top_k
        )

        # Generate the plan
        plan_generator = PlanGenerator(llm)
        final_plan = plan_generator.generate_plan(query.query, final_ranked)

        return TravelResponse(
            recommendations=[final_plan['plan']],
            query=query.query,
            metadata={
                "location": query.location,
                "max_results": query.max_results,
                "sources": final_plan['sources']
            }
        )

    except Exception as e:
        logger.error(f"Error while handling the request: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))

# Health-check endpoint
@router.get("/health")
async def health_check():
    """
    API health-check endpoint.
    """
    return {"status": "healthy"}
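A sketch of exercising the endpoint once the FastAPI app is up (host and port are assumptions about the deployment, not part of the commit):

import requests

resp = requests.post(
    "http://localhost:8000/api/v1/recommend",
    json={"query": "2-day Hong Kong food tour", "max_results": 5},
    timeout=300,  # the pipeline fetches pages and calls an LLM, so allow plenty of time
)
resp.raise_for_status()
body = resp.json()
print(body["recommendations"][0])
print(body["metadata"]["sources"])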
src/api/search_api.py
ADDED
@@ -0,0 +1,263 @@
from typing import List, Dict
import requests
from abc import ABC, abstractmethod
import logging
import json

logging.basicConfig(level=logging.DEBUG)

class SearchEngine(ABC):
    def __init__(self):
        # Blacklist of ad domains
        self.ad_domains = {
            'ads.google.com',
            'doubleclick.net',
            'affiliate.',
            '.ads.',
            'promotion.',
            'sponsored.',
            'partner.',
            'tracking.',
            '.shop.',
            'taobao.com',
            'tmall.com',
            'jd.com',
            'mafengwo.cn',      # Mafengwo
            'ctrip.com',        # Ctrip
            'tour.aoyou.com',   # Tongcheng/Aoyou
            'wannar.com'        # Wannar
        }

    def is_ad_url(self, url: str) -> bool:
        """Check whether a URL is an ad link."""
        url_lower = url.lower()
        return any(ad_domain in url_lower for ad_domain in self.ad_domains)

    def enhance_query(self, query: str) -> str:
        """Enhance the query by adding Hong Kong tourism keywords."""
        if "Hong Kong" not in query:
            query = f"Hong Kong Tourism {query}"
        return query

    @abstractmethod
    def search(self, query: str) -> List[Dict]:
        pass

class GoogleSearch(SearchEngine):
    def __init__(self, api_key: str, cx: str, proxies: Dict[str, str] = None):
        super().__init__()
        self.api_key = api_key
        self.cx = cx
        self.base_url = "https://www.googleapis.com/customsearch/v1"
        self.proxies = proxies or {}

    def filter_results(self, results: List[Dict]) -> List[Dict]:
        """Filter search results."""
        filtered = []
        for result in results:
            url = result['url'].lower()
            # Only filter out ad domains
            if not self.is_ad_url(url):
                filtered.append(result)
        return filtered

    def search(self, query: str) -> List[Dict]:
        # Enhance the query
        enhanced_query = self.enhance_query(query)

        params = {
            'key': self.api_key,
            'cx': self.cx,
            'q': enhanced_query
        }
        response = requests.get(self.base_url, params=params)
        if response.status_code == 200:
            results = response.json()
            return [{
                'title': item['title'],
                'snippet': item['snippet'],
                'url': item['link']
            } for item in results.get('items', [])]
        return []

class BochaSearch(SearchEngine):
    def __init__(self, api_key: str, base_url: str, proxies: Dict[str, str] = None):
        super().__init__()
        self.api_key = api_key
        self.base_url = base_url.rstrip('/')  # remove a possible trailing slash
        self.proxies = proxies or {}

    def search(self, query: str) -> List[Dict]:
        try:
            # Enhance the query
            enhanced_query = self.enhance_query(query)

            headers = {
                'Authorization': f'Bearer {self.api_key}',
                'Content-Type': 'application/json',
                'Connection': 'keep-alive',
                'Accept': '*/*'
            }

            payload = {
                'query': enhanced_query,
                'stream': False  # use a non-streaming response
            }

            # Use the correct endpoint
            endpoint = f"{self.base_url}/v1/ai-search"

            logging.info(f"Requesting the Bocha API...")
            logging.info(f"Enhanced query: {enhanced_query}")

            response = requests.post(
                endpoint,
                headers=headers,
                json=payload,
                proxies=None
            )

            # Log the response details
            logging.info(f"API response status code: {response.status_code}")
            logging.info(f"API response body: {response.text[:500]}...")  # first 500 characters only

            if response.status_code != 200:
                logging.error(f"API request failed, status code: {response.status_code}")
                logging.error(f"Error response: {response.text}")
                return []

            response_json = response.json()
            if response_json.get('code') == 200 and 'messages' in response_json:
                messages = response_json['messages']
                if messages and isinstance(messages, list):
                    for msg in messages:
                        if msg.get('type') == 'source' and msg.get('content_type') == 'webpage':
                            try:
                                content = json.loads(msg['content'])
                                if 'value' in content:
                                    return content['value']
                            except json.JSONDecodeError:
                                logging.error(f"Could not parse message content: {msg['content']}")
                                continue

            logging.error(f"Unexpected API response structure: {response_json}")
            return []
        except Exception as e:
            logging.error(f"Error while handling the API response: {str(e)}")
            return []

    def search_images(self, query: str, count: int = 3) -> List[Dict]:
        """Search for related images."""
        try:
            headers = {
                'Authorization': f'Bearer {self.api_key}',
                'Content-Type': 'application/json'
            }

            # Enhance the query
            enhanced_query = self.enhance_query(query)
            logging.info(f"Enhanced image-search query: {enhanced_query}")

            payload = {
                'query': enhanced_query,
                'freshness': 'oneYear',
                'count': 10,  # fetch more images to ensure enough valid results
                'filter': 'images'
            }

            endpoint = f"{self.base_url}/v1/web-search"

            response = requests.post(
                endpoint,
                headers=headers,
                json=payload,
                timeout=10
            )

            if response.status_code == 200:
                try:
                    data = response.json()
                    logging.info(f"API response structure: {data.keys()}")

                    if data.get('code') == 200 and 'data' in data:
                        data_content = data['data']
                        logging.info(f"'data' field contents: {data_content.keys()}")

                        images = []
                        if 'images' in data_content:
                            image_items = data_content['images'].get('value', [])
                            logging.info(f"Found {len(image_items)} images")

                            for item in image_items:
                                # Simplified filtering: check only the basic requirements
                                if (item.get('contentUrl') and
                                    item.get('width', 0) >= 300 and
                                    item.get('height', 0) >= 300):

                                    image_info = {
                                        'url': item['contentUrl'],
                                        'width': item['width'],
                                        'height': item['height']
                                    }
                                    images.append(image_info)
                                    if len(images) >= count:
                                        break

                        logging.info(f"Returning {len(images)} images")
                        return images[:count]

                except json.JSONDecodeError as e:
                    logging.error(f"JSON parse error: {str(e)}")
                    return []
                except Exception as e:
                    logging.error(f"Error while handling image data: {str(e)}")
                    return []

            logging.error(f"API request failed, status code: {response.status_code}")
            return []

        except Exception as e:
            logging.error(f"Image search failed: {str(e)}")
            return []

"""
class BingSearch(SearchEngine):
    def __init__(self, api_key: str):
        super().__init__()
        self.api_key = api_key
        self.base_url = "https://api.bing.microsoft.com/v7.0/search"

    def search(self, query: str) -> List[Dict]:
        # Only add the Hong Kong tourism keyword
        enhanced_query = f"香港旅游 {query}"

        headers = {'Ocp-Apim-Subscription-Key': self.api_key}
        params = {
            'q': enhanced_query
        }

        response = requests.get(
            self.base_url,
            headers=headers,
            params=params
        )
        results = response.json()

        filtered_results = []
        for item in results.get('webPages', {}).get('value', []):
            if not self.is_ad_url(item['url']):
                filtered_results.append({
                    'title': item['name'],
                    'snippet': item['snippet'],
                    'url': item['url']
                })

        return filtered_results

    def is_trusted_domain(self, url: str) -> bool:
        ""Check whether the URL belongs to a trusted domain""
        return any(
            trusted_domain in url.lower()
            for trusted_domain in self.config['search_settings']['trusted_domains']
        )
"""
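A usage sketch for the Google path (the key and engine ID are placeholders; a real Custom Search API key and cx are required):

from src.api.search_api import GoogleSearch

search = GoogleSearch(api_key="AIza...", cx="0123456789abc")
results = search.filter_results(search.search("family itinerary"))
for r in results[:3]:
    print(r["title"], "->", r["url"])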
src/core/__pycache__/document_processor.cpython-310.pyc
ADDED
Binary file (2.4 kB)

src/core/__pycache__/embeddings.cpython-310.pyc
ADDED
Binary file (1.77 kB)

src/core/__pycache__/html_processor.cpython-310.pyc
ADDED
Binary file (5.69 kB)

src/core/__pycache__/plan_generator.cpython-310.pyc
ADDED
Binary file (3.5 kB)

src/core/__pycache__/ranking.cpython-310.pyc
ADDED
Binary file (3.69 kB)

src/core/__pycache__/reranker.cpython-310.pyc
ADDED
Binary file (1.87 kB)
src/core/_init_.py
ADDED
@@ -0,0 +1,14 @@
"""
Core package initialization.
This module contains the core functionality for the travel RAG system.
"""

from .document_processor import DocumentProcessor
from .ranking import RankingSystem
from .plan_generator import PlanGenerator

__all__ = [
    'DocumentProcessor',
    'RankingSystem',
    'PlanGenerator'
]
src/core/document_processor.py
ADDED
@@ -0,0 +1,80 @@
from typing import List, Dict
from src.core.html_processor import HTMLProcessor
from src.api.llm_api import LLMInterface
import logging
from concurrent.futures import ThreadPoolExecutor, as_completed

class DocumentProcessor:
    def __init__(self, llm: LLMInterface):
        self.html_processor = HTMLProcessor()
        self.llm = llm
        # Add a cache
        self.cache = {}

    def _clean_text(self, text: str) -> str:
        """Clean the text content."""
        import re
        # Remove extra whitespace
        text = re.sub(r'\s+', ' ', text)
        # Remove special characters
        text = re.sub(r'[^\w\s\u4e00-\u9fff。,!?、]', '', text)
        return text.strip()

    def process_documents(self, search_results: List[Dict]) -> List[Dict]:
        processed_docs = []
        batch_size = 5  # batch size

        # Process documents in parallel
        with ThreadPoolExecutor(max_workers=5) as executor:
            futures = []
            for result in search_results:
                if result['url'] in self.cache:
                    processed_docs.append(self.cache[result['url']])
                    continue

                futures.append(
                    executor.submit(self._process_single_doc, result)
                )

            for future in as_completed(futures):
                try:
                    doc = future.result()
                    if doc:
                        self.cache[doc['url']] = doc
                        processed_docs.append(doc)
                except Exception as e:
                    logging.error(f"Failed to process document: {str(e)}")

        return processed_docs[:5]  # cap the number of returned documents

    def _process_single_doc(self, result: Dict) -> Dict:
        try:
            # Fetch the HTML content
            html = self.html_processor.fetch_html(result['url'])
            if not html:
                return None

            # Extract the main content
            content = self.html_processor.extract_main_content(html)
            content = self._clean_text(content)

            if len(content) < 100:  # content too short
                return None

            # Generate a more targeted summary
            summary = self.llm.summarize_document(
                content=content,
                title=result.get('title', ''),
                url=result['url']
            )

            if summary:
                return {
                    'passage': summary,
                    'title': result.get('title', ''),
                    'url': result['url']
                }

        except Exception as e:
            logging.error(f"Failed to process document ({result['url']}): {str(e)}")
            return None
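A sketch of driving the processor end to end (the search results are hypothetical, and a working LLM backend plus network access are assumed):

from src.api.ollama_api import OllamaInterface
from src.core.document_processor import DocumentProcessor

search_results = [{"title": "Ocean Park guide", "url": "https://example.com/ocean-park"}]
processor = DocumentProcessor(OllamaInterface())
docs = processor.process_documents(search_results)  # fetch, clean, summarize
for doc in docs:
    print(doc["title"], len(doc["passage"]))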
src/core/embeddings.py
ADDED
@@ -0,0 +1,41 @@
from sentence_transformers import SentenceTransformer
import torch
import numpy as np
import os
import logging

class EmbeddingModel:
    def __init__(self, model_name="BAAI/bge-m3"):
        try:
            # Use the Hugging Face model ID
            self.model = SentenceTransformer(model_name)
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
            self.model.to(self.device)
            logging.info(f"Loaded embedding model {model_name} on device {self.device}")
        except Exception as e:
            logging.error(f"Failed to load model: {str(e)}")
            raise

    def encode(self, texts, batch_size=32):
        """
        Convert texts into vector representations.
        """
        embeddings = self.model.encode(
            texts,
            batch_size=batch_size,
            show_progress_bar=True,
            normalize_embeddings=True
        )
        return embeddings

    def encode_queries(self, queries):
        """
        Add the special query prefix and encode.
        BGE models recommend prefixing queries with
        "Represent this sentence for searching relevant passages: ".
        """
        prefix = "Represent this sentence for searching relevant passages: "
        if isinstance(queries, str):
            queries = [queries]

        prefixed_queries = [prefix + query for query in queries]
        return self.encode(prefixed_queries)
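Because the embeddings are normalized, a dot product is already cosine similarity; a small retrieval sketch (the model downloads on first use):

import numpy as np
from src.core.embeddings import EmbeddingModel

model = EmbeddingModel()
passages = ["The Peak Tram runs from Central.", "Dim sum is a Cantonese tradition."]
doc_vecs = model.encode(passages)
query_vec = model.encode_queries("How do I get to Victoria Peak?")[0]
scores = doc_vecs @ query_vec  # cosine similarity, since vectors are unit-length
print(passages[int(np.argmax(scores))])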
src/core/html_processor.py
ADDED
@@ -0,0 +1,195 @@
import requests
from bs4 import BeautifulSoup
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import List, Dict
import time
import logging

class HTMLProcessor:
    def __init__(self, timeout: int = 5):
        self.session = requests.Session()
        self.timeout = timeout

    def fetch_html(self, url: str) -> str:
        """Fetch the HTML content of a single URL."""
        headers = {
            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36',
            'Accept': 'text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,*/*;q=0.8',
            'Accept-Language': 'zh-CN,zh;q=0.9,en;q=0.8',
            'Accept-Encoding': 'gzip, deflate, br',
            'Connection': 'keep-alive'
        }

        try:
            logging.info(f"Fetching URL: {url}")
            response = self.session.get(
                url,
                timeout=self.timeout,
                headers=headers,
                verify=False
            )
            response.raise_for_status()

            # Check the response content type
            content_type = response.headers.get('content-type', '')
            if 'text/html' not in content_type.lower():
                logging.warning(f"Non-HTML response: {content_type}")

            # Set the correct encoding
            response.encoding = response.apparent_encoding

            html = response.text
            logging.info(f"Fetched HTML successfully, length: {len(html)}")

            return html

        except requests.Timeout:
            logging.error(f"Timed out fetching URL: {url}")
        except requests.RequestException as e:
            logging.error(f"Failed to fetch URL {url}: {str(e)}")
        except Exception as e:
            logging.error(f"Unexpected error for {url}: {str(e)}")

        return ""

    def fetch_multiple_html(self, urls: List[str], max_urls: int = 10) -> List[Dict]:
        """
        Fetch the HTML content of multiple URLs in parallel.

        Args:
            urls: list of URLs
            max_urls: maximum number of URLs to fetch

        Returns:
            List[Dict]: list of successfully fetched HTML contents
        """
        results = []
        urls = urls[:max_urls]  # only process the first max_urls URLs

        with ThreadPoolExecutor(max_workers=max_urls) as executor:
            # Submit all tasks
            future_to_url = {
                executor.submit(self.fetch_html, url): url
                for url in urls
            }

            # Handle completed tasks
            for future in as_completed(future_to_url):
                url = future_to_url[future]
                try:
                    html = future.result()
                    if html:  # only keep successful results
                        results.append({
                            'url': url,
                            'html': html,
                            'metadata': self.extract_metadata(html)
                        })
                except Exception as e:
                    print(f"Failed to process URL {url}: {e}")

        return results

    def extract_main_content(self, html: str) -> str:
        """Extract the main content from HTML."""
        if not html:
            logging.warning("Input HTML is empty")
            return ""

        try:
            soup = BeautifulSoup(html, 'html.parser')

            # Remove script and style elements
            for script in soup(["script", "style", "iframe", "nav", "footer", "header"]):
                script.decompose()

            # Record the original length
            original_length = len(html)

            # Try to find the main content container
            main_content = None
            possible_content_ids = ['content', 'main', 'article', 'post']
            possible_content_classes = ['content', 'article', 'post', 'main-content']

            # Look up by id
            for content_id in possible_content_ids:
                main_content = soup.find(id=content_id)
                if main_content:
                    break

            # Look up by class
            if not main_content:
                for content_class in possible_content_classes:
                    main_content = soup.find(class_=content_class)
                    if main_content:
                        break

            # Fall back to the full page if no specific container is found
            text = main_content.get_text() if main_content else soup.get_text()

            # Clean the text
            lines = (line.strip() for line in text.splitlines())
            chunks = (phrase.strip() for line in lines for phrase in line.split("  "))
            text = '\n'.join(chunk for chunk in chunks if chunk)

            # Record the processed length
            processed_length = len(text)

            # Log the lengths
            logging.info(f"HTML length before processing: {original_length}, after: {processed_length}")

            # If the processed text is too short, extraction likely failed
            if processed_length < 100 and original_length > 1000:
                logging.warning(f"Extracted content is unusually short: {processed_length} characters")
                return ""

            return text

        except Exception as e:
            logging.error(f"Error extracting main content: {str(e)}")
            return ""

    def extract_metadata(self, html: str) -> dict:
        """Extract metadata from HTML."""
        try:
            soup = BeautifulSoup(html, 'html.parser')
            metadata = {
                'title': '',
                'description': '',
                'keywords': ''
            }

            # Safer title extraction
            title = ''
            if soup.title and soup.title.string:
                title = soup.title.string.strip()
            else:
                # Try to take the title from an h1 tag
                h1 = soup.find('h1')
                if h1:
                    title = h1.get_text().strip()

            # Fall back to a default if there is still no title
            metadata['title'] = title if title else "Unknown title"

            # Extract the meta description
            meta_desc = soup.find('meta', attrs={'name': ['description', 'Description']})
            if meta_desc:
                metadata['description'] = meta_desc.get('content', '').strip()

            # Extract the meta keywords
            meta_keywords = soup.find('meta', attrs={'name': ['keywords', 'Keywords']})
            if meta_keywords:
                metadata['keywords'] = meta_keywords.get('content', '').strip()

            # Make sure every field has a value
            metadata = {k: v if v else 'unknown' for k, v in metadata.items()}

            return metadata

        except Exception as e:
            logging.error(f"Error extracting metadata: {str(e)}")
            return {
                'title': 'Unknown title',
                'description': 'Unknown description',
                'keywords': 'Unknown keywords'
            }
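A quick sketch of the fetch-then-extract flow (any reachable page works; example.com is used for illustration):

from src.core.html_processor import HTMLProcessor

proc = HTMLProcessor(timeout=5)
html = proc.fetch_html("https://example.com")
if html:
    print(proc.extract_metadata(html)["title"])
    print(proc.extract_main_content(html)[:200])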
src/core/plan_generator.py
ADDED
@@ -0,0 +1,120 @@
from typing import List, Dict
from src.api.llm_api import LLMInterface
import logging

class PlanGenerator:
    def __init__(self, llm: LLMInterface):
        self.llm = llm

    def generate_plan(self, query: str, context: List[Dict]) -> Dict:
        """Generate a travel plan."""
        # Make sure the query contains the Hong Kong keyword
        if "Hong Kong" not in query:
            query = f"Hong Kong Tourism {query}"

        logging.info(f"Generating plan for query: {query}")
        logging.info(f"Number of reference documents: {len(context)}")

        # Build the prompt (currently unused: generate_travel_plan builds its own prompt)
        prompt = self._build_prompt(query, context)

        # Generate the plan
        plan = self.llm.generate_travel_plan(query, context)

        # Record the sources
        sources = []
        for doc in context:
            if doc.get('url'):
                sources.append({
                    'url': doc['url'],
                    'title': doc.get('title', 'Unknown Title'),
                    'relevance_score': doc.get('relevance_score', 0)
                })

        return {
            'query': query,
            'plan': plan,
            'sources': sources
        }

    def _build_prompt(self, query: str, context: List[Dict]) -> str:
        """Build the prompt."""
        # Extract key information from the query
        days = self._extract_days(query)

        prompt = f"""Please create a detailed Hong Kong travel plan based on the following information.

User Needs: {query}

Reference Materials:
"""
        # Append the context
        for i, doc in enumerate(context, 1):
            prompt += f"\nSource {i}:\n{doc['passage']}\n"

        prompt += f"""
Please provide a detailed itinerary for a {days}-day trip to Hong Kong, including the following content:

1. Itinerary Overview:
   - Overall itinerary arrangement
   - Key attractions introduction
   - Time allocation suggestions

2. Daily detailed itinerary:
   - Morning activities and attractions
   - Afternoon activities and attractions
   - Evening activities and attractions
   - Specific time allocation
   - Traffic suggestions

3. Traffic Suggestions:
   - Return transportation plan
   - City transportation suggestions
   - Traffic card purchase suggestions
   - Practical traffic APP recommendations

4. Accommodation Recommendations:
   - Recommended area
   - Specific hotel suggestions
   - Booking considerations

5. Food Recommendations:
   - Specialty restaurants
   - Recommended restaurants
   - Snack street recommendations

6. Practical Tips:
   - Weather suggestions
   - Clothing suggestions
   - Essential items
   - Considerations
   - Consumer budget

Please ensure:
1. The itinerary is reasonable, considering the distance between attractions and the time for visiting
2. Provide specific time allocation
3. Include practical local suggestions
4. Consider actual conditions (e.g., traffic time, attraction opening hours, etc.)
5. Provide detailed traffic guidance

Please return the travel plan content directly, without any other explanation."""

        return prompt

    def _extract_days(self, query: str) -> int:
        """Extract the number of days from the query."""
        import re
        # Match common ways of expressing a day count
        patterns = [
            r'(\d+)\s*[天日]',
            r'(\d+)\s*-*\s*days?',
            r'(\d+)\s*-*\s*d'
        ]

        for pattern in patterns:
            match = re.search(pattern, query.lower())
            if match:
                return int(match.group(1))

        # Default to 3 days
        return 3
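The day-count extraction can be exercised on its own; a sketch (passing llm=None is safe here because _extract_days never touches it):

from src.core.plan_generator import PlanGenerator

pg = PlanGenerator(llm=None)
print(pg._extract_days("5天香港自由行"))      # -> 5 (Chinese day pattern)
print(pg._extract_days("plan a 4-day trip"))  # -> 4
print(pg._extract_days("weekend in HK"))      # -> 3 (default)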
src/core/ranking.py
ADDED
@@ -0,0 +1,114 @@
import faiss
import numpy as np
from typing import List, Dict
from .embeddings import EmbeddingModel
from .reranker import Reranker
from sentence_transformers import SentenceTransformer
import logging
from sklearn.metrics.pairwise import cosine_similarity
import torch

class RankingSystem:
    def __init__(self,
                 embedding_model: EmbeddingModel = None,
                 reranker: Reranker = None):
        self.embedding_model = embedding_model or EmbeddingModel()
        self.reranker = reranker or Reranker()
        self.index = None
        self.passages = None
        self.embedding_cache = {}

    def build_index(self, passages: List[Dict]):
        """Build the FAISS index."""
        self.passages = passages
        texts = [p['passage'] for p in passages]

        if not texts:
            logging.warning("No texts to encode")
            return

        embeddings = self.embedding_model.encode(texts)

        if embeddings is None or not hasattr(embeddings, 'shape'):
            logging.error("Encoding result is empty or malformed")
            return

        dimension = embeddings.shape[1]
        self.index = faiss.IndexFlatIP(dimension)
        self.index.add(embeddings.astype('float32'))

    def initial_ranking(self, query: str, passages: List[Dict], initial_top_k: int = 10) -> List[Dict]:
        """Rank documents initially and return the top K results."""
        # Make sure the input format is correct
        if not isinstance(passages[0], dict):
            passages = [{'passage': p} for p in passages]

        # Use cached embeddings
        texts = [p['passage'] for p in passages]
        embeddings = []

        for text in texts:
            if text in self.embedding_cache:
                embeddings.append(self.embedding_cache[text])
            else:
                embedding = self.embedding_model.encode([text])[0]
                self.embedding_cache[text] = embedding
                embeddings.append(embedding)

        embeddings = np.array(embeddings)

        # Compute similarities in one batch
        query_embedding = self.embedding_model.encode([query])[0]
        similarities = np.dot(embeddings, query_embedding)

        # Fast sort
        indices = np.argsort(similarities)[::-1][:initial_top_k]

        ranked_passages = []
        for idx in indices:
            passage = passages[idx].copy()
            passage['retrieval_score'] = float(similarities[idx])
            ranked_passages.append(passage)

        return ranked_passages

    def rerank(self, query: str, initial_ranked: List[Dict], final_top_k: int = 3) -> List[Dict]:
        """Rerank with the reranker."""
        # Apply the reranker
        reranked = self.reranker.rerank(query, initial_ranked)

        # Compute the final score (with adjusted weights)
        for passage in reranked:
            # Give relevance a higher weight
            passage['final_score'] = (
                0.3 * passage['retrieval_score'] +
                0.7 * passage['rerank_score']
            )

        # Sort by the final score
        final_ranked = sorted(
            reranked,
            key=lambda x: x['final_score'],
            reverse=True
        )

        return final_ranked[:final_top_k]

    def retrieve(self, query: str, passages: List[Dict]) -> List[Dict]:
        """
        Retrieve and rank documents.

        Args:
            query: query string
            passages: list of candidate documents

        Returns:
            List[Dict]: ranked list of documents
        """
        # 1. Initial ranking first
        initial_results = self.initial_ranking(query, passages, initial_top_k=10)

        # 2. Then rerank
        final_results = self.rerank(query, initial_results, final_top_k=3)

        return final_results
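A sketch of the two-stage pipeline on toy passages (both models download on first use, so the first run is slow):

from src.core.embeddings import EmbeddingModel
from src.core.reranker import Reranker
from src.core.ranking import RankingSystem

ranking = RankingSystem(EmbeddingModel(), Reranker())
passages = [{"passage": "The Star Ferry crosses Victoria Harbour."},
            {"passage": "Mong Kok is known for street markets."}]
top = ranking.retrieve("harbour crossing", passages)  # embed-recall, then cross-encoder rerank
for p in top:
    print(round(p["final_score"], 3), p["passage"])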
src/core/reranker.py
ADDED
@@ -0,0 +1,44 @@
from typing import List, Dict
from FlagEmbedding import FlagReranker
import logging
import torch
import os
from sentence_transformers import CrossEncoder

class Reranker:
    def __init__(self, model_path="BAAI/bge-reranker-large"):
        try:
            self.model = FlagReranker(
                model_path,
                use_fp16=True,
                device="cuda" if torch.cuda.is_available() else "cpu"
            )
            logging.info(f"Loaded reranker model {model_path} on the {'cuda' if torch.cuda.is_available() else 'cpu'} device")
        except Exception as e:
            logging.error(f"Failed to load reranker model: {str(e)}")
            raise

    def rerank(self, query: str, passages: List[Dict]) -> List[Dict]:
        """
        Rerank documents.
        """
        try:
            # Prepare the list of texts
            texts = [p['passage'] for p in passages]

            # Run the reranking
            scores = self.model.compute_score([[query, text] for text in texts])

            # Attach the score to each original dict
            for passage, score in zip(passages, scores):
                passage['rerank_score'] = float(score)

            # Sort by rerank score
            reranked = sorted(passages, key=lambda x: x['rerank_score'], reverse=True)

            return reranked

        except Exception as e:
            logging.error(f"Error during reranking: {str(e)}")
            # Fall back to the original order if reranking fails
            return passages
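A standalone sketch of the reranker (the first run downloads BAAI/bge-reranker-large):

from src.core.reranker import Reranker

reranker = Reranker()
docs = [{"passage": "The MTR connects Kowloon and Hong Kong Island."},
        {"passage": "Egg waffles are a popular street snack."}]
ranked = reranker.rerank("How to travel between Kowloon and the island?", docs)
print(ranked[0]["passage"], round(ranked[0]["rerank_score"], 3))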
src/retrieval/__pycache__/base.cpython-310.pyc
ADDED
Binary file (842 Bytes)

src/retrieval/__pycache__/graph_rag.cpython-310.pyc
ADDED
Binary file (2.48 kB)

src/retrieval/__pycache__/memo_rag.cpython-310.pyc
ADDED
Binary file (2.62 kB)
src/retrieval/base.py
ADDED
@@ -0,0 +1,15 @@
from abc import ABC, abstractmethod
from typing import List, Dict

class BaseRetriever(ABC):
    """Base class for retrieval strategies."""

    @abstractmethod
    def retrieve(self, query: str, context: List[Dict]) -> List[Dict]:
        """Run retrieval."""
        pass

    @abstractmethod
    def init_retriever(self, config: Dict):
        """Initialize the retriever."""
        pass
src/retrieval/graph_rag.py
ADDED
@@ -0,0 +1,56 @@
from .base import BaseRetriever
from typing import Dict, List
import networkx as nx
# import spacy  # temporarily commented out
from sklearn.cluster import AgglomerativeClustering
import numpy as np
import logging

class GraphRAG(BaseRetriever):
    def __init__(self, config: Dict):
        self.config = config
        self.graph = nx.Graph()
        # self.nlp = spacy.load("zh_core_web_sm")  # temporarily commented out
        self.init_retriever(config)

    def init_retriever(self, config: Dict):
        self.working_dir = config['retrieval_settings']['methods'][2]['model_settings']['working_dir']
        self.graph_file = f"{self.working_dir}/graph.graphml"

    def retrieve(self, query: str, context: List[Dict]) -> List[Dict]:
        # Simple implementation: keyword-match based retrieval
        scored_docs = []
        for doc in context:
            # Use the number of query words appearing in the document as the score
            score = sum(1 for word in query.split() if word in doc['passage'])
            doc_copy = doc.copy()
            doc_copy['graph_score'] = float(score)
            scored_docs.append(doc_copy)

        return sorted(scored_docs, key=lambda x: x['graph_score'], reverse=True)

    def _build_graph(self, context: List[Dict]):
        """Simplified graph construction."""
        # Use simple word-frequency statistics only
        for doc in context:
            text = doc['passage']
            words = text.split()
            # Create edges between adjacent words
            for i in range(len(words)-1):
                w1, w2 = words[i], words[i+1]
                if not self.graph.has_edge(w1, w2):
                    self.graph.add_edge(w1, w2, weight=1)
                else:
                    self.graph[w1][w2]['weight'] += 1

    def _calculate_graph_score(self, query_words: List[str], doc: Dict) -> float:
        """Simplified graph-score computation."""
        score = 0.0
        doc_words = doc['passage'].split()

        for q_word in query_words:
            for d_word in doc_words:
                if self.graph.has_edge(q_word, d_word):
                    score += self.graph[q_word][d_word]['weight']

        return score if score > 0 else 0.0
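A sketch of the keyword-overlap scoring with a minimal config stub (the stub only mirrors the nested keys init_retriever reads; the real values live in config/config.yaml):

from src.retrieval.graph_rag import GraphRAG

config = {"retrieval_settings": {"methods": [
    None, None, {"model_settings": {"working_dir": "/tmp/graph_rag"}}]}}
rag = GraphRAG(config)
docs = [{"passage": "Victoria Peak offers skyline views."},
        {"passage": "Tai O is a fishing village."}]
print(rag.retrieve("Victoria Peak views", docs)[0]["passage"])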
src/retrieval/memo_rag.py
ADDED
@@ -0,0 +1,76 @@
from .base import BaseRetriever
from typing import Dict, List
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from sentence_transformers import SentenceTransformer
import logging
import os

class MemoRAG(BaseRetriever):
    def __init__(self, config: Dict):
        self.config = config
        self.init_retriever(config)

    def init_retriever(self, config: Dict):
        memo_config = config['retrieval_settings']['methods'][1]['model_settings']

        try:
            # Use a locally quantized model
            local_model_path = "/root/.cache/modelscope/hub/MaxLeton13/chatglm3-6B-32k-int4"

            logging.info(f"Loading local model: {local_model_path}")
            self.model = AutoModelForCausalLM.from_pretrained(
                local_model_path,
                device_map="auto",
                trust_remote_code=True
            )
            self.tokenizer = AutoTokenizer.from_pretrained(
                local_model_path,
                trust_remote_code=True
            )

            # Initialize the vector retrieval model
            logging.info(f"Loading vector retrieval model: {memo_config['ret_model']}")
            self.embedding_model = SentenceTransformer(
                memo_config['ret_model'],
                device="cuda" if torch.cuda.is_available() else "cpu"
            )

            # Set up the cache directory
            self.cache_dir = memo_config['cache_dir']
            os.makedirs(self.cache_dir, exist_ok=True)

        except Exception as e:
            logging.error(f"Failed to initialize MemoRAG: {str(e)}")
            raise

    def retrieve(self, query: str, context: List[Dict]) -> List[Dict]:
        try:
            # Use vector retrieval for an initial filter
            query_embedding = self.embedding_model.encode(query)

            # Compute document embeddings
            docs_text = [doc['passage'] for doc in context]
            docs_embeddings = self.embedding_model.encode(docs_text)

            # Compute similarities
            similarities = torch.nn.functional.cosine_similarity(
                torch.tensor(query_embedding).unsqueeze(0),
                torch.tensor(docs_embeddings),
                dim=1
            )

            # Attach a score to each document
            scored_docs = []
            for doc, score in zip(context, similarities):
                doc_copy = doc.copy()
                doc_copy['memory_score'] = float(score)
                scored_docs.append(doc_copy)

            # Sort by score
            return sorted(scored_docs, key=lambda x: x['memory_score'], reverse=True)

        except Exception as e:
            logging.error(f"MemoRAG retrieval failed: {str(e)}")
            # Return the original document list if retrieval fails
            return context
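The retrieval step here reduces to cosine similarity between query and document embeddings; a standalone sketch of just that step, without the ChatGLM dependency (the model name is illustrative):

import torch
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("BAAI/bge-m3")
query_vec = torch.tensor(model.encode("night markets"))
doc_vecs = torch.tensor(model.encode(["Temple Street night market", "Peak Tram history"]))
scores = torch.nn.functional.cosine_similarity(query_vec.unsqueeze(0), doc_vecs, dim=1)
print(scores)  # higher = more relevant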
src/utils/__init__.py
ADDED
@@ -0,0 +1,6 @@
from .helpers import load_config, setup_logging

__all__ = [
    'load_config',
    'setup_logging',
]
src/utils/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (246 Bytes)

src/utils/__pycache__/helpers.cpython-310.pyc
ADDED
Binary file (3 kB)
src/utils/helpers.py
ADDED
@@ -0,0 +1,104 @@
import logging
import yaml
from pathlib import Path
from typing import Dict, Any, Tuple

def load_config(config_path: str = "/root/travel_rag/config/config.yaml") -> Dict[str, Any]:
    """
    Load the configuration from a config file.

    Args:
        config_path: path to the config file, defaults to "/root/travel_rag/config/config.yaml"

    Returns:
        The configuration dictionary.
    """
    config_path = Path(config_path)
    if not config_path.exists():
        raise FileNotFoundError(f"Config file not found: {config_path}")

    with open(config_path, 'r', encoding='utf-8') as f:
        return yaml.safe_load(f)

def setup_logging(
    log_level: str = "INFO",
    log_file: str = None
) -> None:
    """
    Configure logging.

    Args:
        log_level: log level, defaults to "INFO"
        log_file: log file path, defaults to None (console output only)
    """
    # Set the log format
    log_format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"

    # Configure the root logger
    logging.basicConfig(
        level=getattr(logging, log_level.upper()),
        format=log_format
    )

    # If a log file is given, add a file handler
    if log_file:
        file_handler = logging.FileHandler(log_file)
        file_handler.setFormatter(logging.Formatter(log_format))
        logging.getLogger().addHandler(file_handler)

def setup_proxy(proxy_config_path: str = "/root/clash/config.yaml") -> Tuple[Dict[str, str], bool]:
    """
    Set the system proxy and return the proxy configuration and its availability.

    Args:
        proxy_config_path: path to the proxy config file (should be a string)

    Returns:
        Tuple[Dict[str, str], bool]: (proxy config dict, whether the proxy is usable)
    """
    import os
    import logging
    import requests
    from requests.exceptions import RequestException

    logger = logging.getLogger(__name__)

    # Default proxy address
    proxy_url = 'http://127.0.0.1:8880'

    # Read from the config file if it exists
    if os.path.exists(proxy_config_path):
        try:
            config = load_config(proxy_config_path)
            # Adjust to the actual structure of the config file
            proxy_url = config.get('proxy_url', proxy_url)
            logger.info(f"Loaded proxy settings from config file: {proxy_url}")
        except Exception as e:
            logger.warning(f"Failed to load proxy config: {e}, falling back to the default")

    # Set environment variables
    os.environ['HTTP_PROXY'] = proxy_url
    os.environ['HTTPS_PROXY'] = proxy_url

    proxies = {
        'http': proxy_url,
        'https': proxy_url
    }

    # Test whether the proxy is usable
    try:
        response = requests.get('https://www.google.com',
                                proxies=proxies,
                                timeout=5,
                                verify=False)  # verify=False avoids certificate issues
        proxy_available = response.status_code == 200
        if proxy_available:
            logger.info("Proxy server is available")
        else:
            logger.warning(f"Unexpected proxy response, status code: {response.status_code}")
    except RequestException as e:
        logger.warning(f"Could not connect to the proxy server: {e}")
        proxy_available = False

    logger.info(f"Proxy setup finished: {proxies}, available: {proxy_available}")
    return proxies, proxy_available
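A sketch of the helpers in sequence (the log path is a placeholder; setup_proxy probes https://www.google.com, so the availability flag depends on your network):

from src.utils.helpers import setup_logging, setup_proxy

setup_logging(log_level="DEBUG", log_file="/tmp/travel_rag.log")
proxies, available = setup_proxy()
print("proxy usable:", available)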
src/utils/neo4j_helper.py
ADDED
@@ -0,0 +1,82 @@
from neo4j import GraphDatabase
from typing import List, Dict, Any
import logging

class Neo4jConnection:
    def __init__(self, uri: str = "bolt://localhost:7687",
                 user: str = "neo4j",
                 password: str = "your_password"):
        """Initialize the Neo4j connection."""
        try:
            self.driver = GraphDatabase.driver(uri, auth=(user, password))
            logging.info("Neo4j connection established")
        except Exception as e:
            logging.error(f"Neo4j connection failed: {str(e)}")
            raise

    def close(self):
        """Close the connection."""
        if self.driver:
            self.driver.close()

    def run_query(self, query: str, parameters: Dict[str, Any] = None) -> List[Dict]:
        """Run a Cypher query."""
        try:
            with self.driver.session() as session:
                result = session.run(query, parameters or {})
                return [record.data() for record in result]
        except Exception as e:
            logging.error(f"Query execution failed: {str(e)}")
            raise

    def create_node(self, label: str, properties: Dict[str, Any]) -> Dict:
        """Create a node."""
        query = f"""
        CREATE (n:{label} $properties)
        RETURN n
        """
        return self.run_query(query, {"properties": properties})

    def create_relationship(self, start_node_label: str, start_node_props: Dict,
                            end_node_label: str, end_node_props: Dict,
                            relationship_type: str, relationship_props: Dict = None) -> Dict:
        """Create a relationship."""
        query = f"""
        MATCH (a:{start_node_label}), (b:{end_node_label})
        WHERE a.id = $start_props.id AND b.id = $end_props.id
        CREATE (a)-[r:{relationship_type} $rel_props]->(b)
        RETURN a, r, b
        """
        params = {
            "start_props": start_node_props,
            "end_props": end_node_props,
            "rel_props": relationship_props or {}
        }
        return self.run_query(query, params)

    def get_node(self, label: str, properties: Dict[str, Any]) -> Dict:
        """Get a node."""
        query = f"""
        MATCH (n:{label})
        WHERE n.id = $properties.id
        RETURN n
        """
        return self.run_query(query, {"properties": properties})

# Usage example:
if __name__ == "__main__":
    # Create the connection
    neo4j = Neo4jConnection()

    try:
        # Create a sample node
        node = neo4j.create_node("Place", {
            "id": "1",
            "name": "香港迪士尼乐园",  # Hong Kong Disneyland
            "type": "景点"            # attraction
        })
        print("Node created:", node)

    finally:
        # Close the connection
        neo4j.close()