zhuhai111 committed
Commit 7cc8bc0 · verified · 1 Parent(s): 51879dd

Upload 43 files
Files changed (43)
  1. .gitattributes +1 -3
  2. .gradio/certificate.pem +31 -0
  3. .huggingface/YAML +9 -0
  4. README.md +40 -14
  5. app.py +799 -0
  6. config/config.yaml +72 -0
  7. requirements.txt +16 -0
  8. src/__init__.py +4 -0
  9. src/__pycache__/__init__.cpython-310.pyc +0 -0
  10. src/api/__init__.py +47 -0
  11. src/api/__pycache__/__init__.cpython-310.pyc +0 -0
  12. src/api/__pycache__/llm_api.cpython-310.pyc +0 -0
  13. src/api/__pycache__/ollama_api.cpython-310.pyc +0 -0
  14. src/api/__pycache__/routes.cpython-310.pyc +0 -0
  15. src/api/__pycache__/search_api.cpython-310.pyc +0 -0
  16. src/api/llm_api.py +319 -0
  17. src/api/ollama_api.py +101 -0
  18. src/api/routes.py +141 -0
  19. src/api/search_api.py +263 -0
  20. src/core/__pycache__/document_processor.cpython-310.pyc +0 -0
  21. src/core/__pycache__/embeddings.cpython-310.pyc +0 -0
  22. src/core/__pycache__/html_processor.cpython-310.pyc +0 -0
  23. src/core/__pycache__/plan_generator.cpython-310.pyc +0 -0
  24. src/core/__pycache__/ranking.cpython-310.pyc +0 -0
  25. src/core/__pycache__/reranker.cpython-310.pyc +0 -0
  26. src/core/_init_.py +14 -0
  27. src/core/document_processor.py +80 -0
  28. src/core/embeddings.py +41 -0
  29. src/core/html_processor.py +195 -0
  30. src/core/plan_generator.py +120 -0
  31. src/core/ranking.py +114 -0
  32. src/core/reranker.py +44 -0
  33. src/retrieval/__pycache__/base.cpython-310.pyc +0 -0
  34. src/retrieval/__pycache__/graph_rag.cpython-310.pyc +0 -0
  35. src/retrieval/__pycache__/memo_rag.cpython-310.pyc +0 -0
  36. src/retrieval/base.py +15 -0
  37. src/retrieval/graph_rag.py +56 -0
  38. src/retrieval/memo_rag.py +76 -0
  39. src/utils/__init__.py +6 -0
  40. src/utils/__pycache__/__init__.cpython-310.pyc +0 -0
  41. src/utils/__pycache__/helpers.cpython-310.pyc +0 -0
  42. src/utils/helpers.py +104 -0
  43. src/utils/neo4j_helper.py +82 -0
.gitattributes CHANGED
@@ -23,13 +23,11 @@
  *.pth filter=lfs diff=lfs merge=lfs -text
  *.rar filter=lfs diff=lfs merge=lfs -text
  *.safetensors filter=lfs diff=lfs merge=lfs -text
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.tar.* filter=lfs diff=lfs merge=lfs -text
- *.tar filter=lfs diff=lfs merge=lfs -text
  *.tflite filter=lfs diff=lfs merge=lfs -text
  *.tgz filter=lfs diff=lfs merge=lfs -text
  *.wasm filter=lfs diff=lfs merge=lfs -text
  *.xz filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
- *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gradio/certificate.pem ADDED
@@ -0,0 +1,31 @@
1
+ -----BEGIN CERTIFICATE-----
2
+ MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
3
+ TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
4
+ cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
5
+ WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
6
+ ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
7
+ MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
8
+ h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
9
+ 0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
10
+ A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
11
+ T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
12
+ B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
13
+ B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
14
+ KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
15
+ OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
16
+ jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
17
+ qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
18
+ rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
19
+ HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
20
+ hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
21
+ ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
22
+ 3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
23
+ NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
24
+ ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
25
+ TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
26
+ jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
27
+ oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
28
+ 4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
29
+ mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
30
+ emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
31
+ -----END CERTIFICATE-----
.huggingface/YAML ADDED
@@ -0,0 +1,9 @@
1
+ title: Tourism Planning Assistant
2
+ emoji: 🌍
3
+ colorFrom: blue
4
+ colorTo: indigo
5
+ sdk: gradio
6
+ sdk_version: 4.8.0
7
+ app_file: app.py
8
+ pinned: true
9
+ license: mit
README.md CHANGED
@@ -1,14 +1,40 @@
1
- ---
2
- title: Toursim Test
3
- emoji: 📚
4
- colorFrom: blue
5
- colorTo: green
6
- sdk: gradio
7
- sdk_version: 5.20.1
8
- app_file: app.py
9
- pinned: false
10
- license: lgpl-3.0
11
- short_description: Toursim-Test
12
- ---
13
-
14
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
+ # Tourism Planning Assistant
2
+
3
+ The Tourism Planning Assistant is an application built on RAG (Retrieval-Augmented Generation) that helps users plan travel itineraries.
4
+
5
+ ## Features
6
+
7
+ - Generates personalized travel plans from user requirements
8
+ - Uses a search engine to fetch up-to-date tourism information
9
+ - Supports multiple language models (OpenAI, Deepseek)
10
+ - Automatically retrieves and ranks relevant information
11
+ - Displays related reference images
12
+
13
+ ## Deploying to Hugging Face Spaces
14
+
15
+ The app is optimized for direct deployment to Hugging Face Spaces:
16
+
17
+ 1. Create a new Space on Hugging Face and choose Gradio as the SDK
18
+ 2. Add the following environment variables:
19
+ - `HF_BOCHA_API_KEY`: Bocha API key
20
+ - `HF_OPENAI_API_KEY`: OpenAI API key (optional)
21
+ - `HF_DEEPSEEK_API_KEY`: Deepseek API key (optional)
22
+
23
+ ## Usage
24
+
25
+ 1. Enter your travel requirements (e.g., "One-day trip to Hong Kong Disneyland")
26
+ 2. Select the number of travel days
27
+ 3. Click the "Generate Plan" button
28
+ 4. Receive a personalized travel plan with an itinerary and reference links
29
+
30
+ ## Tech Stack
31
+
32
+ - Gradio: user interface
33
+ - Sentence Transformers: text embeddings
34
+ - FlagEmbedding: text reranking
35
+ - Hugging Face models: BGE-M3 and BGE-Reranker-Large
36
+ - Bocha API: search-engine interface
37
+
38
+ ## License
39
+
40
+ MIT License
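
For local testing outside Spaces, a minimal launch sketch may help. It assumes the environment-variable names documented above; the key values are placeholders, and `create_interface` is the factory defined in app.py below.

```python
import os

# Placeholder keys -- substitute your own before running.
os.environ["HF_BOCHA_API_KEY"] = "<your-bocha-key>"
os.environ["HF_DEEPSEEK_API_KEY"] = "<your-deepseek-key>"  # optional
os.environ["HF_OPENAI_API_KEY"] = "<your-openai-key>"      # optional

# app.py launches Gradio on 0.0.0.0:7860 in its __main__ block;
# this is the equivalent of `python app.py`.
from app import create_interface

create_interface().launch(server_name="0.0.0.0", server_port=7860)
```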
app.py ADDED
@@ -0,0 +1,799 @@
1
+ import gradio as gr
2
+ import sys
3
+ import os
4
+ from pathlib import Path
5
+ from typing import List
6
+ # Add the project root to the Python path
7
+ sys.path.append(str(Path(__file__).parent))
8
+ from src.api.search_api import BochaSearch
9
+ from src.core.document_processor import DocumentProcessor
10
+ from src.core.ranking import RankingSystem
11
+ from src.core.plan_generator import PlanGenerator
12
+ from src.core.embeddings import EmbeddingModel
13
+ from src.core.reranker import Reranker
14
+ from src.api.llm_api import DeepseekInterface, LLMInterface, OpenAIInterface
15
+ from src.utils.helpers import load_config
16
+ import logging
17
+
18
+ # Configure logging
19
+ logging.basicConfig(level=logging.INFO)
20
+ logger = logging.getLogger(__name__)
21
+
22
+ class TravelRAGSystem:
23
+ def __init__(self):
24
+ self.config = load_config("config/config.yaml")
25
+ self.llm_instances = {}  # LLM instances keyed by provider
26
+
27
+ # Always use the "standard" retrieval method
28
+ self.retrieval_method = "standard"
29
+
30
+ self.init_llm_instances()
31
+ self.init_components()
32
+
33
+ def init_components(self):
34
+ # Get the default provider's configuration
35
+ default_provider = self.config['llm_settings']['default_provider']
36
+ provider_config = next(
37
+ (p for p in self.config['llm_settings']['providers']
38
+ if p['name'] == default_provider),
39
+ None
40
+ )
41
+
42
+ if not provider_config:
43
+ raise ValueError(f"No configuration found for default provider {default_provider}")
44
+
45
+ # Initialize the LLM instance
46
+ self.llm = self.init_llm(provider_config['name'], provider_config['model'])
47
+
48
+ self.search_engine = BochaSearch(
49
+ api_key=self.config['bocha_api_key'],
50
+ base_url=self.config['bocha_base_url']
51
+ )
52
+
53
+ self.doc_processor = DocumentProcessor(self.llm)
54
+
55
+ # Initialize the embedding model - using a Hugging Face model ID
56
+ try:
57
+ self.embedding_model = EmbeddingModel(
58
+ model_name="BAAI/bge-m3"
59
+ )
60
+ logger.info("Embedding model loaded successfully")
61
+ except Exception as e:
62
+ logger.error(f"Failed to load embedding model: {str(e)}")
63
+ raise
64
+
65
+ # Initialize the reranker - using a Hugging Face model ID
66
+ try:
67
+ self.reranker = Reranker(
68
+ model_path="BAAI/bge-reranker-large"
69
+ )
70
+ logger.info("Reranker model loaded successfully")
71
+ except Exception as e:
72
+ logger.error(f"Failed to load reranker model: {str(e)}")
73
+ raise
74
+
75
+ self.ranking_system = RankingSystem(self.embedding_model, self.reranker)
76
+ self.plan_generator = PlanGenerator(self.llm)
77
+
78
+ def init_llm(self, provider: str, model: str):
79
+ if provider == "openai":
80
+ return OpenAIInterface(
81
+ api_key=self.config['openai_api_key'],
82
+ model=model
83
+ )
84
+ elif provider == "deepseek":
85
+ return DeepseekInterface(
86
+ api_key=self.config['deepseek_api_key'],
87
+ base_url=next(
88
+ p['base_url'] for p in self.config['llm_settings']['providers']
89
+ if p['name'] == 'deepseek'
90
+ ),
91
+ model=model
92
+ )
93
+ else:
94
+ raise ValueError(f"Unsupported LLM provider: {provider}")
95
+
96
+ def init_llm_instances(self):
97
+ """Initialize all enabled LLM instances"""
98
+ for provider in self.config['llm_settings']['providers']:
99
+ if provider.get('enabled', False):
100
+ try:
101
+ if provider['name'] == "openai":
102
+ self.llm_instances['openai'] = OpenAIInterface(
103
+ api_key=self.config['openai_api_key'],
104
+ model=provider['model']
105
+ )
106
+ else:
107
+ self.llm_instances['deepseek'] = DeepseekInterface(
108
+ api_key=self.config['deepseek_api_key'],
109
+ base_url=provider['base_url'],
110
+ model=provider['model']
111
+ )
112
+ logging.info(f"Successfully initialized {provider['name']} LLM")
113
+ except Exception as e:
114
+ logging.error(f"Failed to initialize {provider['name']} LLM: {str(e)}")
115
+
116
+ def get_llm(self, provider_name: str = None) -> LLMInterface:
117
+ """Get the specified LLM instance"""
118
+ if not provider_name:
119
+ provider_name = self.config['llm_settings']['default_provider']
120
+
121
+ if provider_name not in self.llm_instances:
122
+ raise ValueError(f"LLM provider not found or not enabled: {provider_name}")
123
+
124
+ return self.llm_instances[provider_name]
125
+
126
+ def process_query(
127
+ self,
128
+ query: str,
129
+ days: int,
130
+ llm_provider: str,
131
+ llm_model: str,
132
+ enable_images: bool = True,
133
+ retrieval_method: str = None
134
+ ) -> tuple:
135
+ try:
136
+ # Switch retrieval method if a new one is specified
137
+ if retrieval_method and retrieval_method != self.retrieval_method:
138
+ self.set_retrieval_method(retrieval_method)
139
+
140
+ # Fall back to the default provider if the requested one is unavailable
141
+ if llm_provider not in self.llm_instances:
142
+ llm_provider = self.config['llm_settings']['default_provider']
143
+ current_llm = self.llm_instances[llm_provider]
144
+ self.doc_processor = DocumentProcessor(current_llm)
145
+ self.plan_generator = PlanGenerator(current_llm)
146
+
147
+ # Make sure the query includes the number of days
148
+ if days > 0:
149
+ query = f"{query} {days} days"
150
+
151
+ # Run the search
152
+ logger.info(f"Running search: {query}")
153
+ search_results = self.search_engine.search(query)
154
+ logger.info(f"Search results: {search_results}")
155
+
156
+ # Process the documents
157
+ passages = self.doc_processor.process_documents(search_results)
158
+ logger.info(f"Processed passages: {passages}")
159
+
160
+ # Retrieve and rank with the current retriever
161
+ if hasattr(self, 'retriever'):
162
+ final_ranked = self.retriever.retrieve(query, passages)
163
+ else:
164
+ # Fall back to the default ranking system
165
+ initial_ranked = self.ranking_system.initial_ranking(query, passages)
166
+ final_ranked = self.ranking_system.rerank(query, initial_ranked)
167
+
168
+ # Generate the plan
169
+ final_plan = self.plan_generator.generate_plan(query, final_ranked)
170
+ logger.info(f"Generated plan: {final_plan}")
171
+
172
+ # Prepare the reference-sources table
173
+ # Create the table header
174
+ table_header = "| Reference URL | Relevance Score | Retrieval Score | Rerank Score |\n| --- | --- | --- | --- |"
175
+
176
+ # Prepare the table rows
177
+ table_rows = []
178
+ for doc in final_ranked:
179
+ # If the title is empty, use the URL domain as the title
180
+ title = doc.get('title', '').strip()
181
+ if not title:
182
+ from urllib.parse import urlparse
183
+ domain = urlparse(doc['url']).netloc
184
+ title = domain
185
+
186
+ # Build the table row
187
+ row = (
188
+ f"| [{title}]({doc['url']}) | "
189
+ f"{doc.get('final_score', 0):.3f} | "
190
+ f"{doc.get('retrieval_score', 0):.3f} | "
191
+ f"{doc.get('rerank_score', 0):.3f} |"
192
+ )
193
+ table_rows.append(row)
194
+
195
+ # Assemble the table
196
+ sources = table_header + "\n" + "\n".join(table_rows)
197
+
198
+ logger.info(f"Reference sources: {sources}")
199
+
200
+ # Image display section
201
+ image_html = ""
202
+ if enable_images:
203
+ try:
204
+ # Fetch extra results, since some will be filtered out
205
+ images = self.search_engine.search_images(query, count=8)
206
+ valid_images = []
207
+
208
+ if images:
209
+ # Filter the images
210
+ for img in images:
211
+ img_url = img.get('url', '')
212
+ if img_url and self.verify_image_url(img_url):
213
+ valid_images.append(img_url)
214
+ if len(valid_images) >= 3:  # only 3 valid images are needed
215
+ break
216
+
217
+ if valid_images:  # if there are valid images
218
+ image_html = """
219
+ <div style="display: flex; flex-direction: column; gap: 15px;">
220
+ """
221
+ for img_url in valid_images:
222
+ image_html += f"""
223
+ <div style="border-radius: 8px; overflow: hidden;
224
+ box-shadow: 0 4px 6px rgba(0, 0, 0, 0.3);">
225
+ <div style="position: relative; padding-top: 66.67%;">
226
+ <img src="{img_url}"
227
+ alt="Travel-related image"
228
+ style="position: absolute; top: 0; left: 0; width: 100%;
229
+ height: 100%; object-fit: cover; transition: transform 0.3s;"
230
+ onerror="this.style.display='none'">
231
+ </div>
232
+ </div>
233
+ """
234
+ image_html += "</div>"
235
+ except Exception as e:
236
+ logger.warning(f"Error fetching images: {str(e)}")
237
+
238
+ # Prettify the plan content for display
239
+ plan_content = final_plan['plan']
240
+
241
+ # Strip markdown markers
242
+ replacements = {
243
+ '###': '',  # remove triple #
244
+ '##': '',  # remove double #
245
+ '# ': '',  # remove single #
246
+ '**': '',  # remove all **
247
+ }
248
+
249
+ for old, new in replacements.items():
250
+ plan_content = plan_content.replace(old, new)
251
+
252
+ # Process headings and paragraphs
253
+ paragraphs = plan_content.split('\n')
254
+ formatted_paragraphs = []
255
+
256
+ for p in paragraphs:
257
+ p = p.strip()
258
+ if not p:
259
+ continue
260
+
261
+ if "Tour Overview" in p:
262
+ # Main heading style
263
+ formatted_paragraphs.append(
264
+ f'<h2 style="color: #f3f4f6; margin: 10px 0 12px 0; font-size: 1.15em; '
265
+ f'font-weight: 600; letter-spacing: 0.01em; line-height: 1.3; '
266
+ f'border-bottom: 1px solid rgba(99, 102, 241, 0.3); padding-bottom: 8px; '
267
+ f'font-family: -apple-system, BlinkMacSystemFont, Segoe UI, Roboto, Helvetica Neue, sans-serif;">'
268
+ f'📍 {p}</h2>'
269
+ )
270
+ elif ":" in p and any(x in p for x in ["Date", "Destination", "Key Attractions"]):
271
+ # Key-info style, with the value part in bold
272
+ key, value = p.split(":", 1)
273
+ formatted_paragraphs.append(
274
+ f'<div style="display: flex; align-items: start; margin: 4px 0; '
275
+ f'padding-left: 4px;">'
276
+ f'<span style="color: #818cf8; font-weight: 500; min-width: 70px; '
277
+ f'font-size: 0.92em;">{key}:</span>'
278
+ f'<span style="flex: 1; line-height: 1.4; color: #e2e8f0; margin-left: 8px; '
279
+ f'font-family: -apple-system, BlinkMacSystemFont, Segoe UI, Roboto, Helvetica Neue, sans-serif; '
280
+ f'font-size: 0.92em; font-weight: 600; letter-spacing: 0.005em;">{value.strip()}</span>'
281
+ f'</div>'
282
+ )
283
+ elif "Daily Itinerary" in p:
284
+ # Main heading style
285
+ formatted_paragraphs.append(
286
+ f'<h2 style="color: #f3f4f6; margin: 20px 0 12px 0; font-size: 1.15em; '
287
+ f'font-weight: 600; letter-spacing: 0.01em; line-height: 1.3; '
288
+ f'border-bottom: 1px solid rgba(99, 102, 241, 0.3); padding-bottom: 8px; '
289
+ f'font-family: -apple-system, BlinkMacSystemFont, Segoe UI, Roboto, Helvetica Neue, sans-serif;">'
290
+ f'🕒️ {p}</h2>'
291
+ )
292
+ elif " - " in p:  # time-slot heading
293
+ # Subheading style
294
+ formatted_paragraphs.append(
295
+ f'<h3 style="color: #e2e8f0; margin: 14px 0 6px 0; font-size: 1.05em; '
296
+ f'font-weight: 600; letter-spacing: 0.01em; line-height: 1.4; padding-bottom: 4px; '
297
+ f'font-family: -apple-system, BlinkMacSystemFont, Segoe UI, Roboto, Helvetica Neue, sans-serif;">'
298
+ f'🕒 {p}</h3>'
299
+ )
300
+ elif p.startswith("Location") or p.startswith("Activity") or p.startswith("Transportation") or p.startswith("Specific Guidance"):
301
+ # Info-line style
302
+ key, _, value = p.partition(":")  # tolerate lines without a colon
303
+ icon = {
304
+ "Location": "📍",
305
+ "Activity": "🎯",
306
+ "Transportation": "🚇",
307
+ "Specific Guidance": "🗺️"
308
+ }.get(key, "•")
309
+
310
+ formatted_paragraphs.append(
311
+ f'<div style="display: flex; align-items: start; margin: 4px 0; padding-left: 4px;">'
312
+ f'<span style="color: #818cf8; margin-right: 8px;">{icon}</span>'
313
+ f'<span style="flex: 1; line-height: 1.4; color: #e2e8f0; '
314
+ f'font-family: -apple-system, BlinkMacSystemFont, Segoe UI, Roboto, Helvetica Neue, sans-serif; '
315
+ f'font-size: 0.92em; font-weight: 400; letter-spacing: 0.005em;">{value.strip()}</span>'
316
+ f'</div>'
317
+ )
318
+ else:
319
+ # Regular paragraph style
320
+ formatted_paragraphs.append(
321
+ f'<p style="margin: 8px 0; line-height: 1.5; color: #e2e8f0; '
322
+ f'font-family: -apple-system, BlinkMacSystemFont, Segoe UI, Roboto, Helvetica Neue, sans-serif; '
323
+ f'font-size: 0.95em; font-weight: 400; letter-spacing: 0.005em;">{p}</p>'
324
+ )
325
+
326
+ plan_content = '\n'.join(formatted_paragraphs)
327
+
328
+ # Wrap everything in a dark-themed container
329
+ final_output = f"""
330
+ <div style="max-width: 100%; padding: 24px; background: rgba(17, 24, 39, 0.7);
331
+ border-radius: 16px; box-shadow: 0 4px 6px rgba(0, 0, 0, 0.2);">
332
+ <div style="margin-top: 20px;">
333
+ {plan_content}
334
+ </div>
335
+ </div>
336
+ """
337
+
338
+ return final_output, sources, image_html
339
+
340
+ except Exception as e:
341
+ logger.error(f"Error processing query: {str(e)}")
342
+ return f"Sorry, an error occurred while processing your request: {str(e)}", "", ""
343
+
344
+ def verify_image_url(self, url: str) -> bool:
345
+ """Check that an image URL is reachable and meets the content requirements"""
346
+ try:
347
+ import requests
348
+ from PIL import Image
349
+ import io
350
+ import numpy as np
351
+ from PIL import ImageDraw, ImageFont
352
+
353
+ # Fetch the image
354
+ response = requests.get(url, timeout=3)
355
+ if response.status_code != 200:
356
+ return False
357
+
358
+ # Check the content type
359
+ content_type = response.headers.get('content-type', '')
360
+ if 'image' not in content_type.lower():
361
+ return False
362
+
363
+ # Read the image
364
+ img = Image.open(io.BytesIO(response.content))
365
+
366
+ # 1. Check image dimensions
367
+ width, height = img.size
368
+ if width < 300 or height < 300:  # filter out images that are too small
369
+ return False
370
+
371
+ # 2. Check the aspect ratio
372
+ aspect_ratio = width / height
373
+ if aspect_ratio < 0.5 or aspect_ratio > 2.0:  # filter out badly proportioned images
374
+ return False
375
+
376
+ # 3. Convert to a numpy array for analysis
377
+ img_array = np.array(img)
378
+
379
+ # 4. Check whether the image is too uniform (likely a text-only image)
380
+ if len(img_array.shape) == 3:  # make sure it is a color image
381
+ std = np.std(img_array)
382
+ if std < 30:  # a low standard deviation means the image is too flat
383
+ return False
384
+
385
+ # 5. Detect text regions (simple implementation)
386
+ # Convert to grayscale
387
+ if img.mode != 'L':
388
+ img_gray = img.convert('L')
389
+ else:
390
+ img_gray = img
391
+
392
+ # Compute the edge density
393
+ from PIL import ImageFilter
394
+ edges = img_gray.filter(ImageFilter.FIND_EDGES)
395
+ edge_density = np.mean(np.array(edges))
396
+
397
+ # A very high edge density suggests the image contains a lot of text
398
+ if edge_density > 30:
399
+ return False
400
+
401
+ # 6. Check whether the image is oversaturated (likely an ad)
402
+ if len(img_array.shape) == 3:
403
+ hsv = img.convert('HSV')
404
+ saturation = np.array(hsv)[:,:,1]
405
+ if np.mean(saturation) > 200:  # saturation too high
406
+ return False
407
+
408
+ return True
409
+
410
+ except Exception as e:
411
+ logger.warning(f"Image validation failed: {str(e)}")
412
+ return False
413
+
414
+ def _format_images_html(self, images: List[str]) -> str:
415
+ """Format images as HTML for display"""
416
+ if not images:
417
+ return ""
418
+
419
+ # Use a flex layout to display the images
420
+ html = """
421
+ <div style="display: flex; flex-wrap: wrap; gap: 10px; justify-content: center; margin-top: 20px;">
422
+ """
423
+
424
+ for img_url in images:
425
+ # Add an image container with a load-failure fallback
426
+ html += f"""
427
+ <div style="flex: 0 0 calc(50% - 10px); max-width: 300px; min-width: 200px;">
428
+ <img
429
+ src="{img_url}"
430
+ style="width: 100%; height: 200px; object-fit: cover; border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.1);"
431
+ onerror="this.onerror=null; this.src='https://via.placeholder.com/300x200?text=Image+Not+Available';"
432
+ />
433
+ </div>
434
+ """
435
+
436
+ html += "</div>"
437
+
438
+ # Debug logging
439
+ logger.info(f"Generated image HTML: {html[:200]}...")  # only log the first 200 characters
440
+
441
+ return html
442
+
443
+ def set_retrieval_method(self, method: str):
444
+ """Switch the retrieval method"""
445
+ if method not in ["standard"]:
446
+ raise ValueError(f"Unsupported retrieval method: {method}")
447
+
448
+ self.retrieval_method = method
449
+
450
+ # Initialize the retriever for the chosen method
451
+ if method == "standard":
452
+ self.retriever = self.ranking_system
453
+
454
+ def create_interface():
455
+ system = TravelRAGSystem()
456
+
457
+ # Get the list of enabled providers
458
+ enabled_providers = [
459
+ provider['name']
460
+ for provider in system.config['llm_settings']['providers']
461
+ if provider['enabled']
462
+ ]
463
+
464
+ # Map providers to their models
465
+ provider_models = {
466
+ provider['name']: provider['models']
467
+ for provider in system.config['llm_settings']['providers']
468
+ if provider['enabled']
469
+ }
470
+
471
+ # Create the interface with custom CSS
472
+ css = """
473
+ .gradio-container {
474
+ font-family: "PingFang SC", "Microsoft YaHei", sans-serif;
475
+ }
476
+
477
+ /* Target all English text */
478
+ [class*="message-"] {
479
+ font-family: 'Times New Roman', serif !important;
480
+ }
481
+
482
+ /* Make sure English text and digits use Times New Roman */
483
+ .gradio-container *:not(:lang(zh)) {
484
+ font-family: 'Times New Roman', serif !important;
485
+ }
486
+
487
+ @keyframes spin {
488
+ 0% { transform: rotate(0deg); }
489
+ 100% { transform: rotate(360deg); }
490
+ }
491
+
492
+ /* Hide the number input's spin buttons */
493
+ input[type="number"]::-webkit-inner-spin-button,
494
+ input[type="number"]::-webkit-outer-spin-button {
495
+ -webkit-appearance: none;
496
+ margin: 0;
497
+ }
498
+
499
+ input[type="number"] {
500
+ -moz-appearance: textfield;
501
+ }
502
+
503
+ /* Hide the default processing message and arrows */
504
+ .progress-text, .meta-text-center, .progress-container {
505
+ display: none !important;
506
+ }
507
+
508
+ /* Loading animation styles */
509
+ .loading {
510
+ display: flex;
511
+ align-items: center;
512
+ justify-content: center;
513
+ gap: 8px;
514
+ font-size: 1.2em;
515
+ color: rgb(192, 192, 255);
516
+ }
517
+
518
+ .loading::before {
519
+ content: '🌍';
520
+ display: inline-block;
521
+ animation: spin 2s linear infinite;
522
+ filter: brightness(1.5); /* make the globe icon brighter */
523
+ }
524
+
525
+ /* Reposition Gradio's default loading animation */
526
+ .progress-text {
527
+ display: block !important;
528
+ order: 3;
529
+ margin-top: 8px;
530
+ opacity: 0.7;
531
+ }
532
+
533
+ .meta-text-center {
534
+ display: block !important;
535
+ }
536
+
537
+ /* Make sure the loading container uses a flex layout */
538
+ .loading-container {
539
+ display: flex;
540
+ flex-direction: column;
541
+ align-items: center;
542
+ }
543
+
544
+ /* Hide the slider's up/down arrows */
545
+ .num-input-plus, .num-input-minus {
546
+ display: none !important;
547
+ }
548
+
549
+ /* Hide all scroll arrows */
550
+ .scroll-hide,
551
+ .output-markdown,
552
+ .output-text,
553
+ .markdown-text,
554
+ .prose,
555
+ .gr-box,
556
+ .gr-panel {
557
+ -ms-overflow-style: none !important;
558
+ scrollbar-width: none !important;
559
+ overflow-y: hidden !important;
560
+ overflow: hidden !important;
561
+ }
562
+
563
+ .scroll-hide::-webkit-scrollbar,
564
+ .output-markdown::-webkit-scrollbar,
565
+ .output-text::-webkit-scrollbar,
566
+ .markdown-text::-webkit-scrollbar,
567
+ .prose::-webkit-scrollbar,
568
+ .gr-box::-webkit-scrollbar,
569
+ .gr-panel::-webkit-scrollbar {
570
+ display: none !important;
571
+ width: 0 !important;
572
+ height: 0 !important;
573
+ }
574
+
575
+ /* Loading animation container styles */
576
+ .loading-container {
577
+ overflow: hidden !important;
578
+ min-height: 60px;
579
+ }
580
+
581
+ /* Hide Gradio's default scroll controls */
582
+ .wrap.svelte-byatnx,
583
+ .contain.svelte-byatnx,
584
+ [class*='svelte'],
585
+ .gradio-container {
586
+ overflow: hidden !important;
587
+ overflow-y: hidden !important;
588
+ }
589
+
590
+ /* Disable all possible scrollbars */
591
+ ::-webkit-scrollbar {
592
+ display: none !important;
593
+ width: 0 !important;
594
+ height: 0 !important;
595
+ }
596
+
597
+ /* Remove the Group component's default background */
598
+ .custom-group {
599
+ border: none !important;
600
+ background: none !important;
601
+ box-shadow: none !important;
602
+ }
603
+
604
+ .custom-group > div {
605
+ border: none !important;
606
+ background: none !important;
607
+ box-shadow: none !important;
608
+ }
609
+
610
+ /* Image container styles */
611
+ .images-container {
612
+ margin-top: 20px;
613
+ padding: 10px;
614
+ background: #fff;
615
+ border-radius: 8px;
616
+ box-shadow: 0 2px 4px rgba(0,0,0,0.1);
617
+ }
618
+
619
+ .images-container img {
620
+ transition: transform 0.3s ease;
621
+ }
622
+
623
+ .images-container img:hover {
624
+ transform: scale(1.05);
625
+ }
626
+
627
+ /* Make sure the image container is visible */
628
+ #component-13 {
629
+ min-height: 200px;
630
+ overflow: visible !important;
631
+ }
632
+ """
633
+
634
+ # JavaScript for the loading-status text
635
+ js = """
636
+ function showLoading() {
637
+ document.getElementById('loading_status').innerHTML = '<p class="loading">Generating your personalized travel plan...</p>';
638
+ return ['', ''];
639
+ }
640
+ """
641
+
642
+ with gr.Blocks(theme=gr.themes.Soft(), css=css) as interface:
643
+ gr.Markdown("""
644
+ # 🌟 Tourism Planning Assistant 🌟
645
+
646
+ Welcome to the Smart Travel Planning Assistant! Simply input your travel requirements, and we'll generate a personalized travel plan for you.
647
+
648
+ ### Instructions
649
+ 1. Describe your travel needs in the input box (e.g., 'One-day trip to Hong Kong Disneyland')
650
+ 2. Select the number of days for your plan
651
+ 3. Click the "Generate Plan" button
652
+ """)
653
+
654
+ with gr.Row():
655
+ with gr.Column(scale=4):
656
+ llm_provider = gr.Dropdown(
657
+ choices=enabled_providers,
658
+ value=system.config['llm_settings']['default_provider'],
659
+ label="Select LLM Provider"
660
+ )
661
+ llm_model = gr.Dropdown(
662
+ choices=provider_models[system.config['llm_settings']['default_provider']],
663
+ label="Select Model"
664
+ )
665
+
666
+ # Function to update the model choices
667
+ def update_model_choices(provider):
668
+ return gr.Dropdown(choices=provider_models[provider])
669
+
670
+ # Callback for when the provider changes
671
+ llm_provider.change(
672
+ fn=update_model_choices,
673
+ inputs=[llm_provider],
674
+ outputs=[llm_model]
675
+ )
676
+
677
+ query_input = gr.Textbox(
678
+ label="Travel Requirements",
679
+ placeholder="Please enter your travel requirements, e.g.: One-day trip to Hong Kong Disneyland",
680
+ lines=2
681
+ )
682
+ days_input = gr.Slider(
683
+ minimum=1,
684
+ maximum=7,
685
+ value=1,
686
+ step=1,
687
+ label="Number of Days"
688
+ )
689
+
690
+ # Checkbox for fetching images
691
+ show_images = gr.Checkbox(
692
+ label="Search Related Images",
693
+ value=True,
694
+ info="Whether to search and display related reference images"
695
+ )
696
+
697
+ # Only "standard" remains (memorag and graphrag options removed)
698
+ retrieval_method = gr.Radio(
699
+ choices=["standard"],
700
+ value="standard",
701
+ label="Retrieval Method",
702
+ info="Choose different retrieval strategies",
703
+ visible=False  # hidden, since there is only one option
704
+ )
705
+
706
+ submit_btn = gr.Button("Generate Plan", variant="primary")
707
+ loading_status = gr.Markdown("", elem_id="loading_status", show_label=False)
708
+
709
+ # Image display area in the left column
710
+ images_container = gr.HTML(
711
+ value="",  # make sure the initial value is an empty string
712
+ visible=True,
713
+ label="Related Images"
714
+ )
715
+
716
+ # Update the image area when the checkbox state changes
717
+ show_images.change(
718
+ fn=lambda x: "" if not x else "<div></div>",  # return an empty string when images are disabled
719
+ inputs=[show_images],
720
+ outputs=[images_container]
721
+ )
722
+
723
+ with gr.Column(scale=6):
724
+ with gr.Tabs():
725
+ with gr.TabItem("Travel Plan"):
726
+ plan_output = gr.HTML(label="Generated Travel Plan", show_label=False)
727
+ with gr.TabItem("References and Evaluation"):
728
+ sources_output = gr.Markdown(label="References and Evaluation", show_label=False)
729
+
730
+ # Example queries in English
731
+ gr.Examples(
732
+ examples=[
733
+ ["One-day trip to Hong Kong Disneyland", 1],
734
+ ["Family trip to Hong Kong Ocean Park", 1],
735
+ ["Hong Kong Shopping and Food Tour", 2],
736
+ ["Hong Kong Cultural Experience Tour", 3]
737
+ ],
738
+ inputs=[query_input, days_input],
739
+ label="Example Queries"
740
+ )
741
+
742
+ def show_loading():
743
+ loading_html = "<div class='loading-container'><p class='loading'>Generating your personalized travel plan...</p></div>"
744
+ return loading_html, loading_html, "", ""
745
+
746
+ def process_with_images(query, days, llm_provider, llm_model, enable_images, retrieval_method):
747
+ plan_html, sources_md, images_html = system.process_query(
748
+ query, days, llm_provider, llm_model,
749
+ enable_images, retrieval_method
750
+ )
751
+
752
+ # Debug logging
753
+ logger.info(f"Image HTML length: {len(images_html) if images_html else 0}")
754
+
755
+ return plan_html, sources_md, images_html
756
+
757
+ # Wire up the submit button
758
+ submit_btn.click(
759
+ fn=show_loading,
760
+ inputs=None,
761
+ outputs=[loading_status, plan_output, sources_output, images_container]
762
+ ).then(
763
+ fn=process_with_images,
764
+ inputs=[
765
+ query_input,
766
+ days_input,
767
+ llm_provider,
768
+ llm_model,
769
+ show_images,
770
+ retrieval_method
771
+ ],
772
+ outputs=[plan_output, sources_output, images_container]  # order must match
773
+ ).then(
774
+ fn=lambda: "",
775
+ inputs=None,
776
+ outputs=[loading_status]
777
+ )
778
+
779
+ # Footer in English
780
+ gr.Markdown("""
781
+ ### 📝 Notes
782
+ - Plan generation may take some time, please be patient
783
+ - Queries should include specific locations and activity preferences
784
+ - All plans are AI-generated, please adjust according to actual circumstances
785
+
786
+ Powered by RAG for Tourism system © 2024
787
+ """)
788
+
789
+ return interface
790
+
791
+ if __name__ == "__main__":
792
+ demo = create_interface()
793
+ # Settings for the Hugging Face Spaces environment
794
+ demo.launch(
795
+ server_name="0.0.0.0",
796
+ server_port=7860,
797
+ share=False,  # Hugging Face Spaces already provides public access
798
+ debug=False
799
+ )
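
Stripped of the HTML formatting, process_query above is a search-then-rank pipeline. A condensed sketch of the control flow (the method names are taken from the code above; the score-blending comment is an assumption based on the weights in config/config.yaml):

```python
def plan_trip(system, query: str, days: int) -> dict:
    """Condensed view of TravelRAGSystem.process_query (no HTML rendering)."""
    query = f"{query} {days} days"                    # bake the duration into the query
    results = system.search_engine.search(query)      # web search via the Bocha API
    passages = system.doc_processor.process_documents(results)

    # Two-stage ranking: fast embedding retrieval, then cross-encoder reranking.
    initial = system.ranking_system.initial_ranking(query, passages)
    final = system.ranking_system.rerank(query, initial)

    # Presumably, per config.yaml: final_score = 0.3 * retrieval_score + 0.7 * rerank_score
    return system.plan_generator.generate_plan(query, final)
```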
config/config.yaml ADDED
@@ -0,0 +1,72 @@
1
+ # API Keys - set via environment variables on Hugging Face Spaces
2
+ google_api_key: ${HF_GOOGLE_API_KEY}
3
+ google_cx: ${HF_GOOGLE_CX}
4
+ bing_api_key: ${HF_BING_API_KEY}
5
+ openai_api_key: ${HF_OPENAI_API_KEY}
6
+ deepseek_api_key: ${HF_DEEPSEEK_API_KEY}
7
+ bocha_api_key: ${HF_BOCHA_API_KEY}
8
+ bocha_base_url: "https://api.bochaai.com"
9
+
10
+ # LLM Settings
11
+ llm_settings:
12
+ providers:
13
+ - name: "deepseek"
14
+ enabled: true
15
+ model: "deepseek-chat"
16
+ base_url: "https://api.deepseek.com"
17
+ api_key: ${deepseek_api_key}
18
+ models: ["deepseek-chat"]
19
+ - name: "openai"
20
+ enabled: true
21
+ model: "gpt-4o"
22
+ base_url: "https://api.openai.com/v1" # standard OpenAI API URL
23
+ api_key: ${openai_api_key}
24
+ models: ["gpt-4o"]
25
+ default_provider: "deepseek" # provider used by default
26
+
27
+ # Retrieval settings
28
+ retrieval_settings:
29
+ default_method: "standard"
30
+ methods:
31
+ - name: "standard"
32
+ enabled: true
33
+ model_settings:
34
+ embedding_model: "BAAI/bge-m3" # Hugging Face model ID
35
+ reranker_model: "BAAI/bge-reranker-large" # Hugging Face model ID
36
+
37
+ # Search Settings
38
+ max_results: 20
39
+ language: "zh-CN"
40
+ search_provider: "bocha"
41
+
42
+ # Document Processing
43
+ max_passage_length: 500
44
+ min_passage_length: 100
45
+
46
+ # Vector Search Settings
47
+ embedding_model: "BAAI/bge-m3" # Hugging Face model ID
48
+ reranker_model: "BAAI/bge-reranker-large" # Hugging Face model ID
49
+ batch_size: 32
50
+ use_gpu: true
51
+
52
+ # Ranking Settings
53
+ initial_top_k: 100
54
+ final_top_k: 3
55
+ retrieval_weight: 0.3
56
+ rerank_weight: 0.7
57
+
58
+ # Search Settings
59
+ search_settings:
60
+ trusted_domains:
61
+ - 'discoverhongkong.com'
62
+ - 'tourism.gov.hk'
63
+ - 'hong-kong-travel.com'
64
+ - 'timeout.com.hk'
65
+ - 'openrice.com'
66
+ - 'lcsd.gov.hk'
67
+ - 'hkpl.gov.hk'
68
+
69
+ proxy:
70
+ enabled: false
71
+ host: "127.0.0.1"
72
+ port: 8880
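
PyYAML does not expand `${VAR}` placeholders by itself, so load_config (defined in src/utils/helpers.py, whose body is not shown in this diff) presumably substitutes environment variables before parsing. A minimal sketch under that assumption (the nested `${deepseek_api_key}`-style references to other config keys would need an extra resolution pass):

```python
import os
import re
import yaml

_VAR = re.compile(r"\$\{([^}]+)\}")

def load_config(path: str) -> dict:
    """Load a YAML file, expanding ${VAR} placeholders from the environment (sketch)."""
    with open(path, "r", encoding="utf-8") as f:
        raw = f.read()
    expanded = _VAR.sub(lambda m: os.environ.get(m.group(1), ""), raw)
    return yaml.safe_load(expanded)
```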
requirements.txt ADDED
@@ -0,0 +1,16 @@
1
+ requests>=2.31.0
2
+ beautifulsoup4>=4.12.0
3
+ trafilatura>=1.6.1
4
+ torch>=2.0.0
5
+ transformers>=4.36.0
6
+ openai>=1.3.0
7
+ pyyaml>=6.0.1
8
+ faiss-cpu>=1.7.4
9
+ sentence-transformers>=2.2.0
10
+ gradio>=4.8.0
11
+ neo4j>=5.14.0
12
+ langchain>=0.0.350
13
+ matplotlib>=3.8.0
14
+ Pillow>=9.0.0
15
+ numpy>=1.22.0
16
+ FlagEmbedding>=1.1.5
src/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ from .core.html_processor import HTMLProcessor
2
+ from .core.document_processor import DocumentProcessor
3
+
4
+ __all__ = ['HTMLProcessor', 'DocumentProcessor']
src/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (291 Bytes).
 
src/api/__init__.py ADDED
@@ -0,0 +1,47 @@
1
+ from typing import Dict, Any
2
+ from fastapi import FastAPI
3
+ from fastapi.middleware.cors import CORSMiddleware
4
+ from .llm_api import DeepseekInterface
5
+ from .search_api import GoogleSearch
6
+
7
+ def create_app(config: Dict[str, Any] = None) -> FastAPI:
8
+ """
9
+ Create and configure the FastAPI app
10
+ """
11
+ app = FastAPI(
12
+ title="Travel RAG API",
13
+ description="Travel recommendation system using RAG",
14
+ version="1.0.0"
15
+ )
16
+
17
+ # Configure CORS
18
+ app.add_middleware(
19
+ CORSMiddleware,
20
+ allow_origins=["*"],
21
+ allow_credentials=True,
22
+ allow_methods=["*"],
23
+ allow_headers=["*"],
24
+ )
25
+
26
+ # Store the configuration
27
+ if config:
28
+ app.state.config = config
29
+
30
+ # Initialize the LLM (providers are stored as a list in config.yaml)
31
+ app.state.llm = DeepseekInterface(
32
+ api_key=config['deepseek_api_key'],
33
+ base_url=next(p['base_url'] for p in config['llm_settings']['providers'] if p['name'] == 'deepseek'),
34
+ model=next(p['models'][0] for p in config['llm_settings']['providers'] if p['name'] == 'deepseek')
35
+ )
36
+
37
+ # Search-engine initialization happens in init_app
38
+ from .routes import init_app
39
+ init_app(app)
40
+
41
+ # Import and register the routes
42
+ from .routes import router
43
+ app.include_router(router)
44
+
45
+ return app
46
+
47
+ __all__ = ['create_app']
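
Since create_app returns a plain FastAPI instance, it can be served with any ASGI server. A usage sketch, assuming a config dict loaded via load_config:

```python
import uvicorn

from src.api import create_app
from src.utils.helpers import load_config

app = create_app(load_config("config/config.yaml"))

if __name__ == "__main__":
    # Serve the Travel RAG API locally; port 8000 is an arbitrary choice.
    uvicorn.run(app, host="0.0.0.0", port=8000)
```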
src/api/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (1.17 kB).
 
src/api/__pycache__/llm_api.cpython-310.pyc ADDED
Binary file (8.51 kB).
 
src/api/__pycache__/ollama_api.cpython-310.pyc ADDED
Binary file (3.59 kB).
 
src/api/__pycache__/routes.cpython-310.pyc ADDED
Binary file (2.91 kB).
 
src/api/__pycache__/search_api.cpython-310.pyc ADDED
Binary file (6.3 kB).
 
src/api/llm_api.py ADDED
@@ -0,0 +1,319 @@
1
+ from typing import List, Dict
2
+ from abc import ABC, abstractmethod
3
+ from openai import OpenAI
4
+ import logging
5
+ import httpx
6
+
7
+ class LLMInterface(ABC):
8
+ @abstractmethod
9
+ def generate(self, prompt: str) -> str:
10
+ pass
11
+
12
+ class DeepseekInterface(LLMInterface):
13
+ def __init__(self, api_key: str, base_url: str, model: str):
14
+ self.api_key = api_key
15
+ self.base_url = base_url
16
+ self.model = model
17
+ self.headers = {
18
+ "Authorization": f"Bearer {api_key}",
19
+ "Content-Type": "application/json"
20
+ }
21
+
22
+ def _build_system_prompt(self, role: str) -> str:
23
+ """Build the system prompt"""
24
+ roles = {
25
+ "summarizer": "You are a professional tourism content analyst, good at extracting and summarizing key tourism-related information. Please answer in English",
26
+ "planner": "You are a professional travel planner who is good at making detailed travel plans. Please answer in English"
27
+ }
28
+ return roles.get(role, "You are a professional AI assistant. Please answer in English")
29
+
30
+ def generate(self, prompt: str, role: str = "planner") -> str:
31
+ import requests
32
+ from requests.adapters import HTTPAdapter
33
+ from urllib3.util.retry import Retry
34
+
35
+ session = requests.Session()
36
+ retries = Retry(
37
+ total=3,
38
+ backoff_factor=1,
39
+ status_forcelist=[500, 502, 503, 504]
40
+ )
41
+ session.mount('https://', HTTPAdapter(max_retries=retries))
42
+
43
+ payload = {
44
+ "model": self.model,
45
+ "messages": [
46
+ {
47
+ "role": "system",
48
+ "content": self._build_system_prompt(role)
49
+ },
50
+ {
51
+ "role": "user",
52
+ "content": prompt
53
+ }
54
+ ],
55
+ "temperature": 0.7,
56
+ "max_tokens": 2000
57
+ }
58
+
59
+ try:
60
+ response = session.post(
61
+ f"{self.base_url}/v1/chat/completions",
62
+ headers=self.headers,
63
+ json=payload,
64
+ timeout=(10, 60)
65
+ )
66
+ response.raise_for_status()
67
+ return response.json()['choices'][0]['message']['content']
68
+ except requests.exceptions.Timeout:
69
+ print("Deepseek API request timeout, retrying...")
70
+ return "Sorry, due to network issues, content generation is temporarily unavailable. Please try again later."
71
+ except requests.exceptions.RequestException as e:
72
+ print(f"Error calling Deepseek API: {str(e)}")
73
+ return "Sorry, an error occurred while generating content. Please try again later."
74
+
75
+ def summarize_document(self, content: str, title: str, url: str) -> str:
76
+ """Summarize a document with Deepseek"""
77
+ prompt = f"""Please analyze the following tourism web content and generate a rich summary paragraph.
78
+
79
+ Web Title: {title}
80
+ Web Link: {url}
81
+
82
+ Web Content:
83
+ {content[:4000]}
84
+
85
+ Requirements:
86
+ 1. The summary should be between 300-500 words
87
+ 2. Keep the most important tourism information (attractions, suggestions, tips, etc.)
88
+ 3. Use an objective tone
89
+ 4. Information should be accurate and practical
90
+ 5. Remove marketing and advertising content
91
+ 6. Maintain logical coherence
92
+
93
+ Please return the summary content directly, without any other explanation."""
94
+
95
+ return self.generate(prompt, role="summarizer")
96
+
97
+ def generate_travel_plan(self, query: str, context: List[Dict]) -> str:
98
+ # Build a more structured context
99
+ context_text = "\n\n".join([
100
+ f"Source {i+1} ({doc.get('title', 'Unknown Title')}):\n{doc['passage']}"
101
+ for i, doc in enumerate(context)
102
+ ])
103
+
104
+ prompt = f"""As a professional travel planner, please create a detailed travel plan based on the user's needs and reference materials.
105
+
106
+ User Needs: {query}
107
+
108
+ Reference Materials:
109
+ {context_text}
110
+
111
+ Please provide the following content:
112
+ 1. Itinerary Overview (Overall arrangement and key attractions)
113
+ 2. Daily detailed itinerary (includes specific time, location, and transportation methods)
114
+ 3. Traffic suggestions (includes practical APP recommendations)
115
+ 4. Accommodation recommendations (includes specific areas and hotel suggestions)
116
+ 5. Food recommendations (includes specialty restaurants and snacks)
117
+ 6. Practical tips (weather, clothing, essential items, etc.)
118
+
119
+ Requirements:
120
+ 1. The itinerary should be reasonable, considering the distance between attractions
121
+ 2. Provide specific time points
122
+ 3. Include detailed traffic guidance
123
+ 4. Suggestions should be specific and practical
124
+ 5. Consider actual conditions (e.g., opening hours of attractions)
125
+
126
+ Please return the travel plan content directly, without any other explanation."""
127
+
128
+ return self.generate(prompt, role="planner")
129
+
130
+ class OllamaInterface(LLMInterface):
131
+ def __init__(self, base_url: str, model: str):
132
+ self.base_url = base_url.rstrip('/')
133
+ self.model = model
134
+ self.headers = {
135
+ "Content-Type": "application/json"
136
+ }
137
+
138
+ def _build_system_prompt(self, role: str) -> str:
139
+ """Build the system prompt"""
140
+ roles = {
141
+ "summarizer": "You are a professional tourism content analyst, good at extracting and summarizing key tourism-related information. Please answer in English",
142
+ "planner": "You are a professional travel planner who is good at making detailed travel plans. Please answer in English"
143
+ }
144
+ return roles.get(role, "You are a professional AI assistant. Please answer in English")
145
+
146
+ def generate(self, prompt: str, role: str = "planner") -> str:
147
+ import requests
148
+
149
+ payload = {
150
+ "model": self.model,
151
+ "messages": [
152
+ {
153
+ "role": "system",
154
+ "content": self._build_system_prompt(role)
155
+ },
156
+ {
157
+ "role": "user",
158
+ "content": prompt
159
+ }
160
+ ],
161
+ "stream": False
162
+ }
163
+
164
+ try:
165
+ response = requests.post(
166
+ f"{self.base_url}/api/chat",
167
+ headers=self.headers,
168
+ json=payload,
169
+ timeout=(10, 60)
170
+ )
171
+ response.raise_for_status()
172
+ return response.json()['message']['content']
173
+ except Exception as e:
174
+ print(f"Error calling Ollama API: {str(e)}")
175
+ return "Sorry, an error occurred while generating content. Please try again later."
176
+
177
+ def summarize_document(self, content: str, title: str, url: str) -> str:
178
+ """Summarize a document with Ollama"""
179
+ prompt = f"""Please analyze the following tourism web content and generate a rich summary paragraph.
180
+
181
+ Web Title: {title}
182
+ Web Link: {url}
183
+
184
+ Web Content:
185
+ {content[:4000]}
186
+
187
+ Requirements:
188
+ 1. The summary should be between 300-500 words
189
+ 2. Keep the most important tourism information (attractions, suggestions, tips, etc.)
190
+ 3. Use an objective tone
191
+ 4. Information should be accurate and practical
192
+ 5. Remove marketing and advertising content
193
+ 6. Maintain logical coherence
194
+
195
+ Please return the summary content directly, without any other explanation."""
196
+
197
+ return self.generate(prompt, role="summarizer")
198
+
199
+ def generate_travel_plan(self, query: str, context: List[Dict]) -> str:
200
+ # Build a more structured context
201
+ context_text = "\n\n".join([
202
+ f"Source {i+1} ({doc.get('title', 'Unknown Title')}):\n{doc['passage']}"
203
+ for i, doc in enumerate(context)
204
+ ])
205
+
206
+ prompt = f"""As a professional travel planner, please create a detailed travel plan based on the user's needs and reference materials.
207
+
208
+ User Needs: {query}
209
+
210
+ Reference Materials:
211
+ {context_text}
212
+
213
+ Please provide the following content:
214
+ 1. Itinerary Overview (Overall arrangement and key attractions)
215
+ 2. Daily detailed itinerary (includes specific time, location, and transportation methods)
216
+ 3. Traffic suggestions (includes practical APP recommendations)
217
+ 4. Accommodation recommendations (includes specific areas and hotel suggestions)
218
+ 5. Food recommendations (includes specialty restaurants and snacks)
219
+ 6. Practical tips (weather, clothing, essential items, etc.)
220
+
221
+ Requirements:
222
+ 1. The itinerary should be reasonable, considering the distance between attractions
223
+ 2. Provide specific time points
224
+ 3. Include detailed traffic guidance
225
+ 4. Suggestions should be specific and practical
226
+ 5. Consider actual conditions (e.g., opening hours of attractions)
227
+
228
+ Please return the travel plan content directly, without any other explanation."""
229
+
230
+ return self.generate(prompt, role="planner")
231
+
232
+ class OpenAIInterface(LLMInterface):
233
+ def __init__(self, api_key: str, model: str = "gpt-4o", base_url: str = "https://api.feidaapi.com/v1"):
234
+ self.api_key = api_key
235
+ self.model = model
236
+ self.client = OpenAI(api_key=api_key, base_url=base_url)
237
+
238
+ def _build_system_prompt(self, role: str) -> str:
239
+ """Build the system prompt"""
240
+ roles = {
241
+ "summarizer": "You are a professional tourism content analyst, good at extracting and summarizing key tourism-related information. Please answer in English",
242
+ "planner": "You are a professional travel planner who is good at making detailed travel plans. Please answer in English"
243
+ }
244
+ return roles.get(role, "You are a professional AI assistant. Please answer in English")
245
+
246
+ def generate(self, prompt: str, role: str = "planner") -> str:
247
+ try:
248
+ messages = [
249
+ {"role": "system", "content": self._build_system_prompt(role)},
250
+ {"role": "user", "content": prompt}
251
+ ]
252
+
253
+ response = self.client.chat.completions.create(
254
+ model=self.model,
255
+ messages=messages,
256
+ temperature=0.7,
257
+ max_tokens=2000
258
+ )
259
+
260
+ return response.choices[0].message.content
261
+
262
+ except Exception as e:
263
+ logging.error(f"Error calling OpenAI API: {str(e)}")
264
+ return "Sorry, an error occurred while generating content. Please try again later."
265
+
266
+ def summarize_document(self, content: str, title: str, url: str) -> str:
267
+ """Summarize a document with OpenAI"""
268
+ prompt = f"""Please analyze the following tourism web content and generate a rich summary paragraph.
269
+
270
+ Web Title: {title}
271
+ Web Link: {url}
272
+
273
+ Web Content:
274
+ {content[:4000]}
275
+
276
+ Requirements:
277
+ 1. The summary should be between 300-500 words
278
+ 2. Keep the most important tourism information (attractions, suggestions, tips, etc.)
279
+ 3. Use an objective tone
280
+ 4. Information should be accurate and practical
281
+ 5. Remove marketing and advertising content
282
+ 6. Maintain logical coherence
283
+
284
+ Please return the summary content directly, without any other explanation."""
285
+
286
+ return self.generate(prompt, role="summarizer")
287
+
288
+ def generate_travel_plan(self, query: str, context: List[Dict]) -> str:
289
+ # Build a more structured context
290
+ context_text = "\n\n".join([
291
+ f"Source {i+1} ({doc.get('title', 'Unknown Title')}):\n{doc['passage']}"
292
+ for i, doc in enumerate(context)
293
+ ])
294
+
295
+ prompt = f"""As a professional travel planner, please create a detailed travel plan based on the user's needs and reference materials.
296
+
297
+ User Needs: {query}
298
+
299
+ Reference Materials:
300
+ {context_text}
301
+
302
+ Please provide the following content:
303
+ 1. Itinerary Overview (Overall arrangement and key attractions)
304
+ 2. Daily detailed itinerary (includes specific time, location, and transportation methods)
305
+ 3. Traffic suggestions (includes practical APP recommendations)
306
+ 4. Accommodation recommendations (includes specific areas and hotel suggestions)
307
+ 5. Food recommendations (includes specialty restaurants and snacks)
308
+ 6. Practical tips (weather, clothing, essential items, etc.)
309
+
310
+ Requirements:
311
+ 1. The itinerary should be reasonable, considering the distance between attractions
312
+ 2. Provide specific time points
313
+ 3. Include detailed traffic guidance
314
+ 4. Suggestions should be specific and practical
315
+ 5. Consider actual conditions (e.g., opening hours of attractions)
316
+
317
+ Please return the travel plan content directly, without any other explanation."""
318
+
319
+ return self.generate(prompt, role="planner")
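
All three classes implement the same informal surface (generate, summarize_document, generate_travel_plan), so callers can stay provider-agnostic. A sketch, with placeholder keys:

```python
from src.api.llm_api import DeepseekInterface, OpenAIInterface

# Either instance can be dropped into DocumentProcessor or PlanGenerator.
llm = DeepseekInterface(
    api_key="<your-deepseek-key>",          # placeholder
    base_url="https://api.deepseek.com",
    model="deepseek-chat",
)
# llm = OpenAIInterface(api_key="<your-openai-key>", model="gpt-4o")

print(llm.generate("Suggest a half-day itinerary for Tsim Sha Tsui.", role="planner"))
```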
src/api/ollama_api.py ADDED
@@ -0,0 +1,101 @@
1
+ from typing import List, Dict
2
+ from .llm_api import LLMInterface
3
+ import requests
4
+ import json
5
+
6
+ class OllamaInterface(LLMInterface):
7
+ def __init__(self, model_name: str = "qwen2.5:32b_ctx32k", base_url: str = "http://localhost:11434"):
8
+ self.model_name = model_name
9
+ self.base_url = base_url
10
+
11
+ def generate(self, prompt: str) -> str:
12
+ """Generate a response with Ollama"""
13
+ try:
14
+ response = requests.post(
15
+ f"{self.base_url}/api/generate",
16
+ json={
17
+ "model": self.model_name,
18
+ "prompt": prompt,
19
+ "stream": False,
20
+ "options": {
21
+ "temperature": 0.7,
22
+ "top_p": 0.9,
23
+ "top_k": 40,
24
+ }
25
+ }
26
+ )
27
+ response.raise_for_status()
28
+ return response.json()['response']
29
+ except Exception as e:
30
+ print(f"Ollama generation error: {e}")
31
+ return ""
32
+
33
+ def _build_system_prompt(self, role: str) -> str:
34
+ """Build the system prompt"""
35
+ roles = {
36
+ "summarizer": "You are a professional tourism content analyst, good at extracting and summarizing key tourism-related information.",
37
+ "planner": "You are a professional travel planner who is good at making detailed travel plans."
38
+ }
39
+ return roles.get(role, "You are a professional AI assistant.")
40
+
41
+ def summarize_document(self, content: str, title: str, url: str) -> str:
42
+ """Summarize a document with Qwen 2.5"""
43
+ system_prompt = self._build_system_prompt("summarizer")
44
+ prompt = f"""{system_prompt}
45
+
46
+ Please analyze the following tourism web content and generate a rich summary paragraph.
47
+
48
+ Web Title: {title}
49
+ Web Link: {url}
50
+
51
+ Web Content:
52
+ {content[:4000]}
53
+
54
+ Requirements:
55
+ 1. The summary should be between 300-500 words
56
+ 2. Keep the most important tourism information (attractions, suggestions, tips, etc.)
57
+ 3. Use an objective tone
58
+ 4. Information should be accurate and practical
59
+ 5. Remove marketing and advertising content
60
+ 6. Maintain logical coherence
61
+
62
+ Please return the summary content directly, without any other explanation."""
63
+
64
+ return self.generate(prompt)
65
+
66
+ def generate_travel_plan(self, query: str, context: List[Dict]) -> str:
67
+ """Generate a travel plan with Qwen 2.5"""
68
+ system_prompt = self._build_system_prompt("planner")
69
+
70
+ context_text = "\n\n".join([
71
+ f"Source {i+1}:\n{doc['passage']}"
72
+ for i, doc in enumerate(context)
73
+ ])
74
+
75
+ prompt = f"""{system_prompt}
76
+
77
+ Based on the following information, please create a detailed travel plan for the user.
78
+
79
+ User Needs: {query}
80
+
81
+ Reference Information:
82
+ {context_text}
83
+
84
+ Please provide the following content:
85
+ 1. Itinerary overview
86
+ 2. Detailed daily itinerary
87
+ 3. Transportation suggestions
88
+ 4. Accommodation recommendations
89
+ 5. Food recommendations
90
+ 6. Notes and practical tips
91
+
92
+ Requirements:
93
+ 1. The plan should be detailed and practical
94
+ 2. The schedule should be reasonable
95
+ 3. Suggestions should be specific
96
+ 4. Consider actual conditions (e.g., travel time, opening hours of attractions)
97
+ 5. Reasonable details may be added based on the context
98
+
99
+ Please return the travel plan content directly, without any other explanation."""
100
+
101
+ return self.generate(prompt)
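
This module defines a second OllamaInterface alongside the one in llm_api.py; this variant targets a local Ollama server. A usage sketch, assuming Ollama is running on its default port with the named model already pulled:

```python
from src.api.ollama_api import OllamaInterface

llm = OllamaInterface(model_name="qwen2.5:32b_ctx32k",
                      base_url="http://localhost:11434")
print(llm.generate("List three must-see attractions in Hong Kong."))
```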
src/api/routes.py ADDED
@@ -0,0 +1,141 @@
1
+ from fastapi import APIRouter, HTTPException, Request
2
+ from typing import Dict, Any
3
+ from pydantic import BaseModel
4
+ from src.core.document_processor import DocumentProcessor
5
+ from src.core.ranking import RankingSystem
6
+ from src.core.plan_generator import PlanGenerator
7
+ from src.core.embeddings import EmbeddingModel
8
+ from src.core.reranker import Reranker
9
+ from src.api.search_api import GoogleSearch, BochaSearch
10
+ from src.utils.helpers import setup_proxy
11
+ from src.api.llm_api import DeepseekInterface
12
+ import logging
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+ # Create the router
17
+ router = APIRouter(
18
+ tags=["travel"] # Swagger UI tag
19
+ )
20
+
21
+ # Root route
22
+ @router.get("/")
23
+ async def root():
24
+ return {
25
+ "message": "Welcome to the Travel Recommendation System API",
26
+ "status": "running",
27
+ "version": "1.0.0",
28
+ "endpoints": {
29
+ "health_check": "/health",
30
+ "recommendations": "/api/v1/recommend",
31
+ "docs": "/docs"
32
+ }
33
+ }
34
+
35
+ # Request model
36
+ class TravelQuery(BaseModel):
37
+ query: str
38
+ location: str = None
39
+ max_results: int = 10
40
+
41
+ # Response model
42
+ class TravelResponse(BaseModel):
43
+ recommendations: list
44
+ query: str
45
+ metadata: Dict[str, Any]
46
+
47
+ def init_app(app):
48
+ """Initialize the app"""
49
+ # Set up the proxy and get its availability
50
+ proxies, proxy_available = setup_proxy(app.state.config)
51
+ app.state.proxies = proxies
52
+
53
+ # Choose the search engine based on proxy availability
54
+ if proxy_available:
55
+ app.state.search = GoogleSearch(
56
+ api_key=app.state.config['google_api_key'],
57
+ cx=app.state.config['google_cx'],
58
+ proxies=proxies
59
+ )
60
+ logging.info("Using the Google search engine")
61
+ else:
62
+ app.state.search = BochaSearch(
63
+ api_key=app.state.config['bocha_api_key'],
64
+ base_url=app.state.config['bocha_base_url']
65
+ )
66
+ logging.info("Using the Bocha search engine")
67
+
68
+ # Initialize the Deepseek LLM (look it up in the providers list from config.yaml)
69
+ app.state.llm = DeepseekInterface(
70
+ api_key=app.state.config['deepseek_api_key'],
71
+ base_url=next(p['base_url'] for p in app.state.config['llm_settings']['providers'] if p['name'] == 'deepseek'),
72
+ model=next(p['models'][0] for p in app.state.config['llm_settings']['providers'] if p['name'] == 'deepseek')
73
+ )
74
+
75
+ # ... other initialization code ...
76
+
77
+ @router.post("/api/v1/recommend", response_model=TravelResponse)
78
+ async def get_travel_recommendations(query: TravelQuery, request: Request):
79
+ """
80
+ Get travel recommendations
81
+ """
82
+ logger.info(f"Received query request: {query.dict()}")
83
+ try:
84
+ # Use the proxy-configured search instance
85
+ search = request.app.state.search
86
+ llm = request.app.state.llm
87
+
88
+ # Run the search
89
+ logger.info("Starting search...")
90
+ search_results = search.search(query.query)
91
+ logger.info(f"Search finished with {len(search_results)} results")
92
+
93
+ # Process the documents
94
+ doc_processor = DocumentProcessor(llm)
95
+ passages = doc_processor.process_documents(search_results)
96
+ passages = [{'passage': p} for p in passages]
97
+ logging.info(f"Passages structure: {passages[:1]}")  # log the structure of the first element
98
+
99
+ # Initialize the ranking system
100
+ embedding_model = EmbeddingModel("BAAI/bge-m3")
101
+ reranker = Reranker("BAAI/bge-reranker-large")
102
+ ranking_system = RankingSystem(embedding_model, reranker)
103
+
104
+ # Two-stage ranking
105
+ initial_ranked = ranking_system.initial_ranking(
106
+ query.query,
107
+ passages,
108
+ 10 # initial_top_k
109
+ )
110
+
111
+ final_ranked = ranking_system.rerank(
112
+ query.query,
113
+ initial_ranked,
114
+ 3 # final_top_k
115
+ )
116
+
117
+ # Generate the plan
118
+ plan_generator = PlanGenerator(llm)
119
+ final_plan = plan_generator.generate_plan(query.query, final_ranked)
120
+
121
+ return TravelResponse(
122
+ recommendations=[final_plan['plan']],
123
+ query=query.query,
124
+ metadata={
125
+ "location": query.location,
126
+ "max_results": query.max_results,
127
+ "sources": final_plan['sources']
128
+ }
129
+ )
130
+
131
+ except Exception as e:
132
+ logger.error(f"Error while handling the request: {str(e)}", exc_info=True)
133
+ raise HTTPException(status_code=500, detail=str(e))
134
+
135
+ # Health-check endpoint
136
+ @router.get("/health")
137
+ async def health_check():
138
+ """
139
+ API health-check endpoint
140
+ """
141
+ return {"status": "healthy"}
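
The recommend endpoint takes a TravelQuery body and returns a TravelResponse. A client-side sketch, assuming the API is served locally on port 8000 as above:

```python
import requests

resp = requests.post(
    "http://localhost:8000/api/v1/recommend",
    json={"query": "Two-day Hong Kong food tour",
          "location": "Hong Kong",
          "max_results": 10},
    timeout=300,  # search + two-stage ranking + plan generation can be slow
)
resp.raise_for_status()
body = resp.json()
print(body["recommendations"][0])   # the generated plan text
print(body["metadata"]["sources"])  # reference sources from the plan generator
```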
src/api/search_api.py ADDED
@@ -0,0 +1,263 @@
1
+ from typing import List, Dict
2
+ import requests
3
+ from abc import ABC, abstractmethod
4
+ import logging
5
+ import json
6
+
7
+ logging.basicConfig(level=logging.DEBUG)
8
+
9
+ class SearchEngine(ABC):
10
+ def __init__(self):
11
+ # 广告域名黑名单
12
+ self.ad_domains = {
13
+ 'ads.google.com',
14
+ 'doubleclick.net',
15
+ 'affiliate.',
16
+ '.ads.',
17
+ 'promotion.',
18
+ 'sponsored.',
19
+ 'partner.',
20
+ 'tracking.',
21
+ '.shop.',
22
+ 'taobao.com',
23
+ 'tmall.com',
24
+ 'jd.com',
25
+ 'mafengwo.cn', # 蚂蜂窝
26
+ 'ctrip.com', # 携程
27
+ 'tour.aoyou.com', # 同程
28
+ 'wannar.com' # 玩哪儿
29
+ }
30
+
31
+ def is_ad_url(self, url: str) -> bool:
32
+ """检查URL是否为广告链接"""
33
+ url_lower = url.lower()
34
+ return any(ad_domain in url_lower for ad_domain in self.ad_domains)
35
+
36
+ def enhance_query(self, query: str) -> str:
37
+ """增强查询词,添加香港旅游关键词"""
38
+ if "Hong Kong" not in query:
39
+ query = f"Hong Kong Tourism{query}"
40
+ return query
41
+
42
+ @abstractmethod
43
+ def search(self, query: str) -> List[Dict]:
44
+ pass
45
+
46
+ class GoogleSearch(SearchEngine):
47
+ def __init__(self, api_key: str, cx: str, proxies: Dict[str, str] = None):
48
+ super().__init__()
49
+ self.api_key = api_key
50
+ self.cx = cx
51
+ self.base_url = "https://www.googleapis.com/customsearch/v1"
52
+ self.proxies = proxies or {}
53
+
54
+ def filter_results(self, results: List[Dict]) -> List[Dict]:
55
+ """过滤搜索结果"""
56
+ filtered = []
57
+ for result in results:
58
+ url = result['url'].lower()
59
+ # 只过滤广告域名
60
+ if not self.is_ad_url(url):
61
+ filtered.append(result)
62
+ return filtered
63
+
64
+ def search(self, query: str) -> List[Dict]:
65
+ # 增强查询词
66
+ enhanced_query = self.enhance_query(query)
67
+
68
+ params = {
69
+ 'key': self.api_key,
70
+ 'cx': self.cx,
71
+ 'q': enhanced_query
72
+ }
73
+ response = requests.get(self.base_url, params=params)
74
+ if response.status_code == 200:
75
+ results = response.json()
76
+ return [{
77
+ 'title': item['title'],
78
+ 'snippet': item['snippet'],
79
+ 'url': item['link']
80
+ } for item in results.get('items', [])]
81
+ return []
82
+
83
+ class BochaSearch(SearchEngine):
84
+ def __init__(self, api_key: str, base_url: str, proxies: Dict[str, str] = None):
85
+ super().__init__()
86
+ self.api_key = api_key
87
+ self.base_url = base_url.rstrip('/') # 移除末尾可能的斜杠
88
+ self.proxies = proxies or {}
89
+
90
+ def search(self, query: str) -> List[Dict]:
91
+ try:
92
+ # 增强查询词
93
+ enhanced_query = self.enhance_query(query)
94
+
95
+ headers = {
96
+ 'Authorization': f'Bearer {self.api_key}',
97
+ 'Content-Type': 'application/json',
98
+ 'Connection': 'keep-alive',
99
+ 'Accept': '*/*'
100
+ }
101
+
102
+ payload = {
103
+ 'query': enhanced_query,
104
+ 'stream': False # 使用非流式返回
105
+ }
106
+
107
+ # 使用正确的端点
108
+ endpoint = f"{self.base_url}/v1/ai-search"
109
+
110
+ logging.info(f"正在请求博查API...")
111
+ logging.info(f"增强后的查询词: {enhanced_query}")
112
+
113
+ response = requests.post(
114
+ endpoint,
115
+ headers=headers,
116
+ json=payload,
117
+ proxies=None
118
+ )
119
+
120
+ # 详细打印响应信息
121
+ logging.info(f"API响应状态码: {response.status_code}")
122
+ logging.info(f"API响应内容: {response.text[:500]}...") # 只打印前500个字符
123
+
124
+ if response.status_code != 200:
125
+ logging.error(f"API请求失败,状态码: {response.status_code}")
126
+ logging.error(f"错误响应: {response.text}")
127
+ return []
128
+
129
+ response_json = response.json()
130
+ if response_json.get('code') == 200 and 'messages' in response_json:
131
+ messages = response_json['messages']
132
+ if messages and isinstance(messages, list):
133
+ for msg in messages:
134
+ if msg.get('type') == 'source' and msg.get('content_type') == 'webpage':
135
+ try:
136
+ content = json.loads(msg['content'])
137
+ if 'value' in content:
138
+ return content['value']
139
+ except json.JSONDecodeError:
140
+ logging.error(f"无法解析消息内容: {msg['content']}")
141
+ continue
142
+
143
+ logging.error(f"API返回数据结构异常: {response_json}")
144
+ return []
145
+ except Exception as e:
146
+ logging.error(f"处理API响应时出错: {str(e)}")
147
+ return []
148
+
149
+ def search_images(self, query: str, count: int = 3) -> List[Dict]:
150
+ """搜索相关图片"""
151
+ try:
152
+ headers = {
153
+ 'Authorization': f'Bearer {self.api_key}',
154
+ 'Content-Type': 'application/json'
155
+ }
156
+
157
+ # 增强查询词
158
+ enhanced_query = self.enhance_query(query)
159
+ logging.info(f"增强后的图片搜索查询: {enhanced_query}")
160
+
161
+ payload = {
162
+ 'query': enhanced_query,
163
+ 'freshness': 'oneYear',
164
+ 'count': 10, # 搜索更多图片以确保有足够的有效结果
165
+ 'filter': 'images'
166
+ }
167
+
168
+ endpoint = f"{self.base_url}/v1/web-search"
169
+
170
+ response = requests.post(
171
+ endpoint,
172
+ headers=headers,
173
+ json=payload,
174
+ timeout=10
175
+ )
176
+
177
+ if response.status_code == 200:
178
+ try:
179
+ data = response.json()
180
+ logging.info(f"API返回数据结构: {data.keys()}")
181
+
182
+ if data.get('code') == 200 and 'data' in data:
183
+ data_content = data['data']
184
+ logging.info(f"data字段内容: {data_content.keys()}")
185
+
186
+ images = []
187
+ if 'images' in data_content:
188
+ image_items = data_content['images'].get('value', [])
189
+ logging.info(f"找到 {len(image_items)} 张图片")
190
+
191
+ for item in image_items:
192
+ # 简化过滤条件,只检查基本必要条件
193
+ if (item.get('contentUrl') and
194
+ item.get('width', 0) >= 300 and
195
+ item.get('height', 0) >= 300):
196
+
197
+ image_info = {
198
+ 'url': item['contentUrl'],
199
+ 'width': item['width'],
200
+ 'height': item['height']
201
+ }
202
+ images.append(image_info)
203
+ if len(images) >= count:
204
+ break
205
+
206
+ logging.info(f"最终返回 {len(images)} 张图片")
207
+ return images[:count]
208
+
209
+ except json.JSONDecodeError as e:
210
+ logging.error(f"JSON解析错误: {str(e)}")
211
+ return []
212
+ except Exception as e:
213
+ logging.error(f"处理图片数据时出错: {str(e)}")
214
+ return []
215
+
216
+ logging.error(f"API请求失败,状态码: {response.status_code}")
217
+ return []
218
+
219
+ except Exception as e:
220
+ logging.error(f"图片搜索出错: {str(e)}")
221
+ return []
222
+
223
+ """
224
+ class BingSearch(SearchEngine):
225
+ def __init__(self, api_key: str):
226
+ super().__init__()
227
+ self.api_key = api_key
228
+ self.base_url = "https://api.bing.microsoft.com/v7.0/search"
229
+
230
+ def search(self, query: str) -> List[Dict]:
231
+ # 只添加香港旅游关键词
232
+ enhanced_query = f"香港旅游 {query}"
233
+
234
+ headers = {'Ocp-Apim-Subscription-Key': self.api_key}
235
+ params = {
236
+ 'q': enhanced_query
237
+ }
238
+
239
+ response = requests.get(
240
+ self.base_url,
241
+ headers=headers,
242
+ params=params
243
+ )
244
+ results = response.json()
245
+
246
+ filtered_results = []
247
+ for item in results.get('webPages', {}).get('value', []):
248
+ if not self.is_ad_url(item['url']):
249
+ filtered_results.append({
250
+ 'title': item['name'],
251
+ 'snippet': item['snippet'],
252
+ 'url': item['url']
253
+ })
254
+
255
+ return filtered_results
256
+
257
+ def is_trusted_domain(self, url: str) -> bool:
258
+ ""检查是否为可信域名""
259
+ return any(
260
+ trusted_domain in url.lower()
261
+ for trusted_domain in self.config['search_settings']['trusted_domains']
262
+ )
263
+ """
src/core/__pycache__/document_processor.cpython-310.pyc ADDED
Binary file (2.4 kB).

src/core/__pycache__/embeddings.cpython-310.pyc ADDED
Binary file (1.77 kB).

src/core/__pycache__/html_processor.cpython-310.pyc ADDED
Binary file (5.69 kB).

src/core/__pycache__/plan_generator.cpython-310.pyc ADDED
Binary file (3.5 kB).

src/core/__pycache__/ranking.cpython-310.pyc ADDED
Binary file (3.69 kB).

src/core/__pycache__/reranker.cpython-310.pyc ADDED
Binary file (1.87 kB).
src/core/_init_.py ADDED
@@ -0,0 +1,14 @@
+ """
+ Core package initialization.
+ This module contains the core functionality for the travel RAG system.
+ """
+
+ from .document_processor import DocumentProcessor
+ from .ranking import RankingSystem
+ from .plan_generator import PlanGenerator
+
+ __all__ = [
+     'DocumentProcessor',
+     'RankingSystem',
+     'PlanGenerator'
+ ]
src/core/document_processor.py ADDED
@@ -0,0 +1,80 @@
+ from typing import List, Dict
+ from src.core.html_processor import HTMLProcessor
+ from src.api.llm_api import LLMInterface
+ import logging
+ import re
+ from concurrent.futures import ThreadPoolExecutor, as_completed
+
+ class DocumentProcessor:
+     def __init__(self, llm: LLMInterface):
+         self.html_processor = HTMLProcessor()
+         self.llm = llm
+         # Cache of processed documents, keyed by URL
+         self.cache = {}
+
+     def _clean_text(self, text: str) -> str:
+         """Clean text content."""
+         # Collapse extra whitespace
+         text = re.sub(r'\s+', ' ', text)
+         # Remove special characters (keep word chars, whitespace, CJK and common Chinese punctuation)
+         text = re.sub(r'[^\w\s\u4e00-\u9fff。,!?、]', '', text)
+         return text.strip()
+
+     def process_documents(self, search_results: List[Dict]) -> List[Dict]:
+         processed_docs = []
+
+         # Process documents in parallel
+         with ThreadPoolExecutor(max_workers=5) as executor:
+             futures = []
+             for result in search_results:
+                 if result['url'] in self.cache:
+                     processed_docs.append(self.cache[result['url']])
+                     continue
+
+                 futures.append(
+                     executor.submit(self._process_single_doc, result)
+                 )
+
+             for future in as_completed(futures):
+                 try:
+                     doc = future.result()
+                     if doc:
+                         self.cache[doc['url']] = doc
+                         processed_docs.append(doc)
+                 except Exception as e:
+                     logging.error(f"Failed to process document: {str(e)}")
+
+         return processed_docs[:5]  # Limit the number of returned documents
+
+     def _process_single_doc(self, result: Dict) -> Dict:
+         try:
+             # Fetch the HTML content
+             html = self.html_processor.fetch_html(result['url'])
+             if not html:
+                 return None
+
+             # Extract the main content
+             content = self.html_processor.extract_main_content(html)
+             content = self._clean_text(content)
+
+             if len(content) < 100:  # Content too short
+                 return None
+
+             # Generate a more targeted summary
+             summary = self.llm.summarize_document(
+                 content=content,
+                 title=result.get('title', ''),
+                 url=result['url']
+             )
+
+             if summary:
+                 return {
+                     'passage': summary,
+                     'title': result.get('title', ''),
+                     'url': result['url']
+                 }
+
+         except Exception as e:
+             logging.error(f"Failed to process document ({result['url']}): {str(e)}")
+         return None
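
A minimal sketch of the processing flow, assuming an `llm` object that implements `summarize_document()` as used above (the URL is illustrative):

    from src.core.document_processor import DocumentProcessor

    processor = DocumentProcessor(llm)  # llm: any LLMInterface implementation
    docs = processor.process_documents([
        {"url": "https://example.com/hk-guide", "title": "HK Guide"},
    ])
    # Each entry: {'passage': <LLM summary>, 'title': ..., 'url': ...};
    # repeated URLs are served from self.cache on later calls
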
src/core/embeddings.py ADDED
@@ -0,0 +1,41 @@
+ from sentence_transformers import SentenceTransformer
+ import torch
+ import logging
+
+ class EmbeddingModel:
+     def __init__(self, model_name="BAAI/bge-m3"):
+         try:
+             # Load by Hugging Face model ID
+             self.model = SentenceTransformer(model_name)
+             self.device = "cuda" if torch.cuda.is_available() else "cpu"
+             self.model.to(self.device)
+             logging.info(f"Loaded embedding model {model_name} on device {self.device}")
+         except Exception as e:
+             logging.error(f"Failed to load model: {str(e)}")
+             raise
+
+     def encode(self, texts, batch_size=32):
+         """
+         Encode texts into vector representations.
+         """
+         embeddings = self.model.encode(
+             texts,
+             batch_size=batch_size,
+             show_progress_bar=True,
+             normalize_embeddings=True
+         )
+         return embeddings
+
+     def encode_queries(self, queries):
+         """
+         Add the special query prefix and encode.
+         BGE models recommend prefixing queries with
+         "Represent this sentence for searching relevant passages: ".
+         """
+         prefix = "Represent this sentence for searching relevant passages: "
+         if isinstance(queries, str):
+             queries = [queries]
+
+         prefixed_queries = [prefix + query for query in queries]
+         return self.encode(prefixed_queries)
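
Because `normalize_embeddings=True` L2-normalizes the vectors, a plain dot product equals cosine similarity. A small sketch of the asymmetric query/passage encoding (texts are illustrative):

    import numpy as np
    from src.core.embeddings import EmbeddingModel

    model = EmbeddingModel("BAAI/bge-m3")
    q = model.encode_queries("best dim sum in Hong Kong")        # prefixed, shape (1, dim)
    p = model.encode(["Tim Ho Wan serves classic dim sum ..."])  # shape (1, dim)
    score = float(np.dot(q[0], p[0]))  # cosine similarity in [-1, 1]
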
src/core/html_processor.py ADDED
@@ -0,0 +1,195 @@
+ import requests
+ from bs4 import BeautifulSoup
+ from concurrent.futures import ThreadPoolExecutor, as_completed
+ from typing import List, Dict
+ import logging
+
+ class HTMLProcessor:
+     def __init__(self, timeout: int = 5):
+         self.session = requests.Session()
+         self.timeout = timeout
+
+     def fetch_html(self, url: str) -> str:
+         """Fetch the HTML content of a single URL."""
+         headers = {
+             'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36',
+             'Accept': 'text/html,application/xhtml+xml,application/xml;q=0.9,image/webp,*/*;q=0.8',
+             'Accept-Language': 'zh-CN,zh;q=0.9,en;q=0.8',
+             'Accept-Encoding': 'gzip, deflate, br',
+             'Connection': 'keep-alive'
+         }
+
+         try:
+             logging.info(f"Fetching URL: {url}")
+             response = self.session.get(
+                 url,
+                 timeout=self.timeout,
+                 headers=headers,
+                 verify=False
+             )
+             response.raise_for_status()
+
+             # Check the response content type
+             content_type = response.headers.get('content-type', '')
+             if 'text/html' not in content_type.lower():
+                 logging.warning(f"Non-HTML response: {content_type}")
+
+             # Set the correct encoding
+             response.encoding = response.apparent_encoding
+
+             html = response.text
+             logging.info(f"Fetched HTML successfully, length: {len(html)}")
+
+             return html
+
+         except requests.Timeout:
+             logging.error(f"Timed out fetching URL: {url}")
+         except requests.RequestException as e:
+             logging.error(f"Failed to fetch URL {url}: {str(e)}")
+         except Exception as e:
+             logging.error(f"Unexpected error for {url}: {str(e)}")
+
+         return ""
+
+     def fetch_multiple_html(self, urls: List[str], max_urls: int = 10) -> List[Dict]:
+         """
+         Fetch the HTML content of multiple URLs in parallel.
+
+         Args:
+             urls: List of URLs
+             max_urls: Maximum number of URLs to fetch
+
+         Returns:
+             List[Dict]: List of successfully fetched HTML contents
+         """
+         results = []
+         urls = urls[:max_urls]  # Only process the first max_urls URLs
+
+         with ThreadPoolExecutor(max_workers=max_urls) as executor:
+             # Submit all tasks
+             future_to_url = {
+                 executor.submit(self.fetch_html, url): url
+                 for url in urls
+             }
+
+             # Handle completed tasks
+             for future in as_completed(future_to_url):
+                 url = future_to_url[future]
+                 try:
+                     html = future.result()
+                     if html:  # Only keep successful fetches
+                         results.append({
+                             'url': url,
+                             'html': html,
+                             'metadata': self.extract_metadata(html)
+                         })
+                 except Exception as e:
+                     logging.error(f"Failed to process URL {url}: {e}")
+
+         return results
+
+     def extract_main_content(self, html: str) -> str:
+         """Extract the main content from HTML."""
+         if not html:
+             logging.warning("Input HTML is empty")
+             return ""
+
+         try:
+             soup = BeautifulSoup(html, 'html.parser')
+
+             # Remove script, style and boilerplate elements
+             for script in soup(["script", "style", "iframe", "nav", "footer", "header"]):
+                 script.decompose()
+
+             # Record the original length
+             original_length = len(html)
+
+             # Try to find the main content container
+             main_content = None
+             possible_content_ids = ['content', 'main', 'article', 'post']
+             possible_content_classes = ['content', 'article', 'post', 'main-content']
+
+             # Look up by ID
+             for content_id in possible_content_ids:
+                 main_content = soup.find(id=content_id)
+                 if main_content:
+                     break
+
+             # Look up by class
+             if not main_content:
+                 for content_class in possible_content_classes:
+                     main_content = soup.find(class_=content_class)
+                     if main_content:
+                         break
+
+             # Fall back to the full page if no specific container is found
+             text = main_content.get_text() if main_content else soup.get_text()
+
+             # Clean the text (split on double spaces, the usual BeautifulSoup recipe)
+             lines = (line.strip() for line in text.splitlines())
+             chunks = (phrase.strip() for line in lines for phrase in line.split("  "))
+             text = '\n'.join(chunk for chunk in chunks if chunk)
+
+             # Record the post-processing length
+             processed_length = len(text)
+
+             # Log both lengths
+             logging.info(f"HTML length before processing: {original_length}, after: {processed_length}")
+
+             # If the processed text is too short, extraction likely failed
+             if processed_length < 100 and original_length > 1000:
+                 logging.warning(f"Extracted content is abnormally short: {processed_length} characters")
+                 return ""
+
+             return text
+
+         except Exception as e:
+             logging.error(f"Error extracting main content: {str(e)}")
+             return ""
+
+     def extract_metadata(self, html: str) -> dict:
+         """Extract metadata from HTML."""
+         try:
+             soup = BeautifulSoup(html, 'html.parser')
+             metadata = {
+                 'title': '',
+                 'description': '',
+                 'keywords': ''
+             }
+
+             # Safer title extraction
+             title = ''
+             if soup.title and soup.title.string:
+                 title = soup.title.string.strip()
+             else:
+                 # Try to extract the title from an h1 tag
+                 h1 = soup.find('h1')
+                 if h1:
+                     title = h1.get_text().strip()
+
+             # Fall back to a default if there is still no title
+             metadata['title'] = title if title else "Unknown title"
+
+             # Extract the meta description
+             meta_desc = soup.find('meta', attrs={'name': ['description', 'Description']})
+             if meta_desc:
+                 metadata['description'] = meta_desc.get('content', '').strip()
+
+             # Extract the meta keywords
+             meta_keywords = soup.find('meta', attrs={'name': ['keywords', 'Keywords']})
+             if meta_keywords:
+                 metadata['keywords'] = meta_keywords.get('content', '').strip()
+
+             # Make sure every field has a value
+             metadata = {k: v if v else 'unknown' for k, v in metadata.items()}
+
+             return metadata
+
+         except Exception as e:
+             logging.error(f"Error extracting metadata: {str(e)}")
+             return {
+                 'title': 'Unknown title',
+                 'description': 'Unknown description',
+                 'keywords': 'Unknown keywords'
+             }
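
Typical standalone usage of the processor (example.com stands in for a real page):

    from src.core.html_processor import HTMLProcessor

    proc = HTMLProcessor(timeout=5)
    html = proc.fetch_html("https://example.com")
    if html:
        text = proc.extract_main_content(html)  # "" when extraction looks broken
        meta = proc.extract_metadata(html)      # title / description / keywords
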
src/core/plan_generator.py ADDED
@@ -0,0 +1,120 @@
+ from typing import List, Dict
+ from src.api.llm_api import LLMInterface
+ import logging
+ import re
+
+ class PlanGenerator:
+     def __init__(self, llm: LLMInterface):
+         self.llm = llm
+
+     def generate_plan(self, query: str, context: List[Dict]) -> Dict:
+         """Generate a travel plan."""
+         # Make sure the query mentions Hong Kong
+         if "Hong Kong" not in query:
+             query = f"Hong Kong Tourism {query}"
+
+         logging.info(f"Generating plan for query: {query}")
+         logging.info(f"Number of reference documents: {len(context)}")
+
+         # Build the prompt
+         # NOTE: currently unused; generate_travel_plan() receives the raw
+         # query and context instead of this prompt
+         prompt = self._build_prompt(query, context)
+
+         # Generate the plan
+         plan = self.llm.generate_travel_plan(query, context)
+
+         # Record the sources
+         sources = []
+         for doc in context:
+             if doc.get('url'):
+                 sources.append({
+                     'url': doc['url'],
+                     'title': doc.get('title', 'Unknown Title'),
+                     'relevance_score': doc.get('relevance_score', 0)
+                 })
+
+         return {
+             'query': query,
+             'plan': plan,
+             'sources': sources
+         }
+
+     def _build_prompt(self, query: str, context: List[Dict]) -> str:
+         """Build the prompt."""
+         # Extract key information from the query
+         days = self._extract_days(query)
+
+         prompt = f"""Please create a detailed Hong Kong travel plan based on the following information.
+
+ User Needs: {query}
+
+ Reference Materials:
+ """
+         # Append the context
+         for i, doc in enumerate(context, 1):
+             prompt += f"\nSource {i}:\n{doc['passage']}\n"
+
+         prompt += f"""
+ Please provide a detailed itinerary for a {days}-day trip to Hong Kong, including the following content:
+
+ 1. Itinerary Overview:
+    - Overall itinerary arrangement
+    - Introduction to key attractions
+    - Suggested time allocation
+
+ 2. Daily Detailed Itinerary:
+    - Morning activities and attractions
+    - Afternoon activities and attractions
+    - Evening activities and attractions
+    - Specific time allocation
+    - Transportation suggestions
+
+ 3. Transportation Suggestions:
+    - Round-trip transportation plan
+    - In-city transportation suggestions
+    - Transit card purchase suggestions
+    - Useful transport app recommendations
+
+ 4. Accommodation Recommendations:
+    - Recommended areas
+    - Specific hotel suggestions
+    - Booking considerations
+
+ 5. Food Recommendations:
+    - Specialty restaurants
+    - Recommended restaurants
+    - Snack street recommendations
+
+ 6. Practical Tips:
+    - Weather advice
+    - Clothing advice
+    - Essential items
+    - Things to watch out for
+    - Spending budget
+
+ Please ensure:
+ 1. The itinerary is reasonable, accounting for distances between attractions and visiting time
+ 2. Specific time allocations are provided
+ 3. Practical local tips are included
+ 4. Real-world constraints are considered (e.g., travel time, attraction opening hours)
+ 5. Detailed transportation guidance is provided
+
+ Please return the travel plan content directly, without any other explanation."""
+
+         return prompt
+
+     def _extract_days(self, query: str) -> int:
+         """Extract the number of days from the query."""
+         # Match common ways of expressing a day count
+         patterns = [
+             r'(\d+)\s*[天日]',
+             r'(\d+)\s*-*\s*days?',
+             r'(\d+)\s*-*\s*d'
+         ]
+
+         for pattern in patterns:
+             match = re.search(pattern, query.lower())
+             if match:
+                 return int(match.group(1))
+
+         # Default to 3 days
+         return 3
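
The day-count extraction accepts both Chinese and English phrasing and falls back to 3. A few illustrative calls (an `llm` implementing `generate_travel_plan()` is assumed for the constructor):

    from src.core.plan_generator import PlanGenerator

    pg = PlanGenerator(llm)
    pg._extract_days("香港3天游")       # -> 3, matches "3天"
    pg._extract_days("a 5-day trip")    # -> 5, matches "5-day"
    pg._extract_days("weekend trip")    # -> 3, the default
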
src/core/ranking.py ADDED
@@ -0,0 +1,114 @@
+ import faiss
+ import numpy as np
+ from typing import List, Dict
+ from .embeddings import EmbeddingModel
+ from .reranker import Reranker
+ import logging
+
+ class RankingSystem:
+     def __init__(self,
+                  embedding_model: EmbeddingModel = None,
+                  reranker: Reranker = None):
+         self.embedding_model = embedding_model or EmbeddingModel()
+         self.reranker = reranker or Reranker()
+         self.index = None
+         self.passages = None
+         self.embedding_cache = {}
+
+     def build_index(self, passages: List[Dict]):
+         """Build the FAISS index."""
+         self.passages = passages
+         texts = [p['passage'] for p in passages]
+
+         if not texts:
+             logging.warning("No texts to encode")
+             return
+
+         embeddings = self.embedding_model.encode(texts)
+
+         if embeddings is None or not hasattr(embeddings, 'shape'):
+             logging.error("Encoding result is empty or malformed")
+             return
+
+         dimension = embeddings.shape[1]
+         self.index = faiss.IndexFlatIP(dimension)
+         self.index.add(embeddings.astype('float32'))
+
+     def initial_ranking(self, query: str, passages: List[Dict], initial_top_k: int = 10) -> List[Dict]:
+         """Rank documents by embedding similarity and return the top K."""
+         if not passages:
+             return []
+
+         # Make sure the input has the expected shape
+         if not isinstance(passages[0], dict):
+             passages = [{'passage': p} for p in passages]
+
+         # Reuse cached embeddings
+         texts = [p['passage'] for p in passages]
+         embeddings = []
+
+         for text in texts:
+             if text in self.embedding_cache:
+                 embeddings.append(self.embedding_cache[text])
+             else:
+                 embedding = self.embedding_model.encode([text])[0]
+                 self.embedding_cache[text] = embedding
+                 embeddings.append(embedding)
+
+         embeddings = np.array(embeddings)
+
+         # Compute similarities in one batch; the embeddings are normalized,
+         # so the dot product is cosine similarity
+         query_embedding = self.embedding_model.encode([query])[0]
+         similarities = np.dot(embeddings, query_embedding)
+
+         # Sort and take the top K
+         indices = np.argsort(similarities)[::-1][:initial_top_k]
+
+         ranked_passages = []
+         for idx in indices:
+             passage = passages[idx].copy()
+             passage['retrieval_score'] = float(similarities[idx])
+             ranked_passages.append(passage)
+
+         return ranked_passages
+
+     def rerank(self, query: str, initial_ranked: List[Dict], final_top_k: int = 3) -> List[Dict]:
+         """Rerank with the cross-encoder reranker."""
+         reranked = self.reranker.rerank(query, initial_ranked)
+
+         # Compute the final score (weights favour the reranker); defaults
+         # guard against the reranker's fallback path, which returns
+         # passages without a rerank_score
+         for passage in reranked:
+             passage['final_score'] = (
+                 0.3 * passage.get('retrieval_score', 0.0) +
+                 0.7 * passage.get('rerank_score', 0.0)
+             )
+
+         # Sort by final score
+         final_ranked = sorted(
+             reranked,
+             key=lambda x: x['final_score'],
+             reverse=True
+         )
+
+         return final_ranked[:final_top_k]
+
+     def retrieve(self, query: str, passages: List[Dict]) -> List[Dict]:
+         """
+         Retrieve and rank documents.
+
+         Args:
+             query: Query string
+             passages: Candidate documents
+
+         Returns:
+             List[Dict]: Ranked documents
+         """
+         # 1. Initial ranking
+         initial_results = self.initial_ranking(query, passages, initial_top_k=10)
+
+         # 2. Reranking
+         final_results = self.rerank(query, initial_results, final_top_k=3)
+
+         return final_results
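
End-to-end, `retrieve()` chains the two stages. A minimal sketch (the passages are illustrative):

    from src.core.embeddings import EmbeddingModel
    from src.core.reranker import Reranker
    from src.core.ranking import RankingSystem

    ranking = RankingSystem(EmbeddingModel(), Reranker())
    passages = [{"passage": "Ocean Park has a kids' zone ..."},
                {"passage": "The Peak Tram runs until late evening ..."}]
    top = ranking.retrieve("things to do with kids", passages)
    # each result carries retrieval_score, rerank_score and
    # final_score = 0.3 * retrieval_score + 0.7 * rerank_score
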
src/core/reranker.py ADDED
@@ -0,0 +1,44 @@
+ from typing import List, Dict
+ from FlagEmbedding import FlagReranker
+ import logging
+ import torch
+
+ class Reranker:
+     def __init__(self, model_path="BAAI/bge-reranker-large"):
+         try:
+             self.model = FlagReranker(
+                 model_path,
+                 use_fp16=True,
+                 device="cuda" if torch.cuda.is_available() else "cpu"
+             )
+             logging.info(f"Loaded reranker model {model_path} on {'cuda' if torch.cuda.is_available() else 'cpu'}")
+         except Exception as e:
+             logging.error(f"Failed to load the reranker model: {str(e)}")
+             raise
+
+     def rerank(self, query: str, passages: List[Dict]) -> List[Dict]:
+         """
+         Rerank the documents.
+         """
+         try:
+             # Collect the texts
+             texts = [p['passage'] for p in passages]
+
+             # Score (query, text) pairs with the cross-encoder
+             scores = self.model.compute_score([[query, text] for text in texts])
+
+             # Attach the scores to the original dicts
+             for passage, score in zip(passages, scores):
+                 passage['rerank_score'] = float(score)
+
+             # Sort by rerank score
+             reranked = sorted(passages, key=lambda x: x['rerank_score'], reverse=True)
+
+             return reranked
+
+         except Exception as e:
+             logging.error(f"Error during reranking: {str(e)}")
+             # Fall back to the original order if reranking fails
+             return passages
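
A small standalone call, showing the score field the reranker adds (passages are illustrative):

    from src.core.reranker import Reranker

    rr = Reranker("BAAI/bge-reranker-large")
    ranked = rr.rerank("cheap eats in Mong Kok", [
        {"passage": "Street food stalls line the Mong Kok markets ..."},
        {"passage": "Luxury hotels cluster on Hong Kong Island ..."},
    ])
    print(ranked[0]["rerank_score"])  # highest score first
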
src/retrieval/__pycache__/base.cpython-310.pyc ADDED
Binary file (842 Bytes).

src/retrieval/__pycache__/graph_rag.cpython-310.pyc ADDED
Binary file (2.48 kB).

src/retrieval/__pycache__/memo_rag.cpython-310.pyc ADDED
Binary file (2.62 kB).
src/retrieval/base.py ADDED
@@ -0,0 +1,15 @@
+ from abc import ABC, abstractmethod
+ from typing import List, Dict
+
+ class BaseRetriever(ABC):
+     """Base class for retrieval strategies."""
+
+     @abstractmethod
+     def retrieve(self, query: str, context: List[Dict]) -> List[Dict]:
+         """Run retrieval."""
+         pass
+
+     @abstractmethod
+     def init_retriever(self, config: Dict):
+         """Initialize the retriever."""
+         pass
src/retrieval/graph_rag.py ADDED
@@ -0,0 +1,56 @@
+ from .base import BaseRetriever
+ from typing import Dict, List
+ import networkx as nx
+ # import spacy  # temporarily disabled
+
+ class GraphRAG(BaseRetriever):
+     def __init__(self, config: Dict):
+         self.config = config
+         self.graph = nx.Graph()
+         # self.nlp = spacy.load("zh_core_web_sm")  # temporarily disabled
+         self.init_retriever(config)
+
+     def init_retriever(self, config: Dict):
+         self.working_dir = config['retrieval_settings']['methods'][2]['model_settings']['working_dir']
+         self.graph_file = f"{self.working_dir}/graph.graphml"
+
+     def retrieve(self, query: str, context: List[Dict]) -> List[Dict]:
+         # Simple implementation: keyword-match based retrieval
+         scored_docs = []
+         for doc in context:
+             # Score = number of query tokens that appear in the document
+             score = sum(1 for word in query.split() if word in doc['passage'])
+             doc_copy = doc.copy()
+             doc_copy['graph_score'] = float(score)
+             scored_docs.append(doc_copy)
+
+         return sorted(scored_docs, key=lambda x: x['graph_score'], reverse=True)
+
+     def _build_graph(self, context: List[Dict]):
+         """Simplified graph construction."""
+         # Uses only simple co-occurrence counts
+         for doc in context:
+             text = doc['passage']
+             words = text.split()
+             # Add edges between adjacent words
+             for i in range(len(words) - 1):
+                 w1, w2 = words[i], words[i + 1]
+                 if not self.graph.has_edge(w1, w2):
+                     self.graph.add_edge(w1, w2, weight=1)
+                 else:
+                     self.graph[w1][w2]['weight'] += 1
+
+     def _calculate_graph_score(self, query_words: List[str], doc: Dict) -> float:
+         """Simplified graph score computation."""
+         score = 0.0
+         doc_words = doc['passage'].split()
+
+         for q_word in query_words:
+             for d_word in doc_words:
+                 if self.graph.has_edge(q_word, d_word):
+                     score += self.graph[q_word][d_word]['weight']
+
+         return score if score > 0 else 0.0
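
Since this simplified retriever only needs `working_dir` from the config, a minimal stub is enough to try it; the stub below mirrors the nesting `init_retriever` expects:

    from src.retrieval.graph_rag import GraphRAG

    config = {"retrieval_settings": {"methods": [
        None, None, {"model_settings": {"working_dir": "/tmp/graph_rag"}}]}}
    rag = GraphRAG(config)
    docs = rag.retrieve("Victoria Peak tram",
                        [{"passage": "The Peak Tram climbs to Victoria Peak"}])
    # graph_score simply counts query tokens found in each passage
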
src/retrieval/memo_rag.py ADDED
@@ -0,0 +1,76 @@
+ from .base import BaseRetriever
+ from typing import Dict, List
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ from sentence_transformers import SentenceTransformer
+ import logging
+ import os
+
+ class MemoRAG(BaseRetriever):
+     def __init__(self, config: Dict):
+         self.config = config
+         self.init_retriever(config)
+
+     def init_retriever(self, config: Dict):
+         memo_config = config['retrieval_settings']['methods'][1]['model_settings']
+
+         try:
+             # Use the locally quantized model
+             local_model_path = "/root/.cache/modelscope/hub/MaxLeton13/chatglm3-6B-32k-int4"
+
+             logging.info(f"Loading local model: {local_model_path}")
+             self.model = AutoModelForCausalLM.from_pretrained(
+                 local_model_path,
+                 device_map="auto",
+                 trust_remote_code=True
+             )
+             self.tokenizer = AutoTokenizer.from_pretrained(
+                 local_model_path,
+                 trust_remote_code=True
+             )
+
+             # Initialize the embedding retrieval model
+             logging.info(f"Loading embedding retrieval model: {memo_config['ret_model']}")
+             self.embedding_model = SentenceTransformer(
+                 memo_config['ret_model'],
+                 device="cuda" if torch.cuda.is_available() else "cpu"
+             )
+
+             # Set up the cache directory
+             self.cache_dir = memo_config['cache_dir']
+             os.makedirs(self.cache_dir, exist_ok=True)
+
+         except Exception as e:
+             logging.error(f"Failed to initialize MemoRAG: {str(e)}")
+             raise
+
+     def retrieve(self, query: str, context: List[Dict]) -> List[Dict]:
+         try:
+             # Coarse filtering with vector retrieval
+             query_embedding = self.embedding_model.encode(query)
+
+             # Compute document embeddings
+             docs_text = [doc['passage'] for doc in context]
+             docs_embeddings = self.embedding_model.encode(docs_text)
+
+             # Compute similarities
+             similarities = torch.nn.functional.cosine_similarity(
+                 torch.tensor(query_embedding).unsqueeze(0),
+                 torch.tensor(docs_embeddings),
+                 dim=1
+             )
+
+             # Attach a score to each document
+             scored_docs = []
+             for doc, score in zip(context, similarities):
+                 doc_copy = doc.copy()
+                 doc_copy['memory_score'] = float(score)
+                 scored_docs.append(doc_copy)
+
+             # Sort by score
+             return sorted(scored_docs, key=lambda x: x['memory_score'], reverse=True)
+
+         except Exception as e:
+             logging.error(f"MemoRAG retrieval failed: {str(e)}")
+             # Fall back to the original document list on failure
+             return context
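
Note that initialization depends on a hard-coded local ChatGLM3 checkpoint and on `config['retrieval_settings']['methods'][1]['model_settings']` supplying `ret_model` and `cache_dir`. Given those, retrieval itself is plain cosine scoring:

    memo = MemoRAG(config)  # raises if the local model path is missing
    ranked = memo.retrieve("night markets", passages)
    # each doc gains a 'memory_score'; on failure the original list is returned
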
src/utils/__init__.py ADDED
@@ -0,0 +1,6 @@
+ # setup_proxy is exported as well, since src/api/routes.py uses it
+ from .helpers import load_config, setup_logging, setup_proxy
+
+ __all__ = [
+     'load_config',
+     'setup_logging',
+     'setup_proxy',
+ ]
src/utils/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (246 Bytes).

src/utils/__pycache__/helpers.cpython-310.pyc ADDED
Binary file (3 kB).
src/utils/helpers.py ADDED
@@ -0,0 +1,104 @@
+ import logging
+ import yaml
+ from pathlib import Path
+ from typing import Dict, Any, Tuple
+
+ def load_config(config_path: str = "/root/travel_rag/config/config.yaml") -> Dict[str, Any]:
+     """
+     Load configuration from a config file.
+
+     Args:
+         config_path: Config file path, defaults to "/root/travel_rag/config/config.yaml"
+
+     Returns:
+         Configuration dictionary
+     """
+     config_path = Path(config_path)
+     if not config_path.exists():
+         raise FileNotFoundError(f"Config file not found: {config_path}")
+
+     with open(config_path, 'r', encoding='utf-8') as f:
+         return yaml.safe_load(f)
+
+ def setup_logging(
+     log_level: str = "INFO",
+     log_file: str = None
+ ) -> None:
+     """
+     Configure logging.
+
+     Args:
+         log_level: Log level, defaults to "INFO"
+         log_file: Log file path, defaults to None (console output only)
+     """
+     # Log format
+     log_format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+
+     # Configure the root logger
+     logging.basicConfig(
+         level=getattr(logging, log_level.upper()),
+         format=log_format
+     )
+
+     # If a log file is given, add a file handler
+     if log_file:
+         file_handler = logging.FileHandler(log_file)
+         file_handler.setFormatter(logging.Formatter(log_format))
+         logging.getLogger().addHandler(file_handler)
+
+ def setup_proxy(proxy_config_path: str = "/root/clash/config.yaml") -> Tuple[Dict[str, str], bool]:
+     """
+     Configure the system proxy and return the proxy settings and availability.
+
+     Args:
+         proxy_config_path: Proxy config file path (must be a string)
+
+     Returns:
+         Tuple[Dict[str, str], bool]: (proxy settings dict, whether the proxy is usable)
+     """
+     import os
+     import requests
+     from requests.exceptions import RequestException
+
+     logger = logging.getLogger(__name__)
+
+     # Default proxy address
+     proxy_url = 'http://127.0.0.1:8880'
+
+     # If a config file exists, read the proxy address from it
+     if os.path.exists(proxy_config_path):
+         try:
+             config = load_config(proxy_config_path)
+             # Adjust to the actual config file structure
+             proxy_url = config.get('proxy_url', proxy_url)
+             logger.info(f"Loaded proxy settings from config file: {proxy_url}")
+         except Exception as e:
+             logger.warning(f"Failed to load proxy config: {e}, using defaults")
+
+     # Set environment variables
+     os.environ['HTTP_PROXY'] = proxy_url
+     os.environ['HTTPS_PROXY'] = proxy_url
+
+     proxies = {
+         'http': proxy_url,
+         'https': proxy_url
+     }
+
+     # Test whether the proxy works
+     try:
+         response = requests.get('https://www.google.com',
+                                 proxies=proxies,
+                                 timeout=5,
+                                 verify=False)  # verify=False avoids certificate issues
+         proxy_available = response.status_code == 200
+         if proxy_available:
+             logger.info("Proxy server is available")
+         else:
+             logger.warning(f"Proxy server returned an unexpected status code: {response.status_code}")
+     except RequestException as e:
+         logger.warning(f"Failed to connect to the proxy server: {e}")
+         proxy_available = False
+
+     logger.info(f"Proxy setup complete: {proxies}, available: {proxy_available}")
+     return proxies, proxy_available
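
How the helpers fit together at startup (paths shown are the module defaults; the log file name is an assumption):

    from src.utils.helpers import load_config, setup_logging, setup_proxy

    setup_logging(log_level="INFO", log_file="travel_rag.log")
    config = load_config()       # /root/travel_rag/config/config.yaml
    proxies, ok = setup_proxy()  # probes https://www.google.com via the proxy
    engine = "google" if ok else "bocha"  # mirrors the choice made in routes.init_app
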
src/utils/neo4j_helper.py ADDED
@@ -0,0 +1,82 @@
+ from neo4j import GraphDatabase
+ from typing import List, Dict, Any
+ import logging
+
+ class Neo4jConnection:
+     def __init__(self, uri: str = "bolt://localhost:7687",
+                  user: str = "neo4j",
+                  password: str = "your_password"):
+         """Initialize the Neo4j connection."""
+         try:
+             self.driver = GraphDatabase.driver(uri, auth=(user, password))
+             logging.info("Neo4j connection established")
+         except Exception as e:
+             logging.error(f"Neo4j connection failed: {str(e)}")
+             raise
+
+     def close(self):
+         """Close the connection."""
+         if self.driver:
+             self.driver.close()
+
+     def run_query(self, query: str, parameters: Dict[str, Any] = None) -> List[Dict]:
+         """Run a Cypher query."""
+         try:
+             with self.driver.session() as session:
+                 result = session.run(query, parameters or {})
+                 return [record.data() for record in result]
+         except Exception as e:
+             logging.error(f"Query execution failed: {str(e)}")
+             raise
+
+     def create_node(self, label: str, properties: Dict[str, Any]) -> Dict:
+         """Create a node."""
+         query = f"""
+         CREATE (n:{label} $properties)
+         RETURN n
+         """
+         return self.run_query(query, {"properties": properties})
+
+     def create_relationship(self, start_node_label: str, start_node_props: Dict,
+                             end_node_label: str, end_node_props: Dict,
+                             relationship_type: str, relationship_props: Dict = None) -> Dict:
+         """Create a relationship."""
+         query = f"""
+         MATCH (a:{start_node_label}), (b:{end_node_label})
+         WHERE a.id = $start_props.id AND b.id = $end_props.id
+         CREATE (a)-[r:{relationship_type} $rel_props]->(b)
+         RETURN a, r, b
+         """
+         params = {
+             "start_props": start_node_props,
+             "end_props": end_node_props,
+             "rel_props": relationship_props or {}
+         }
+         return self.run_query(query, params)
+
+     def get_node(self, label: str, properties: Dict[str, Any]) -> Dict:
+         """Fetch a node."""
+         query = f"""
+         MATCH (n:{label})
+         WHERE n.id = $properties.id
+         RETURN n
+         """
+         return self.run_query(query, {"properties": properties})
+
+ # Usage example:
+ if __name__ == "__main__":
+     # Create the connection
+     neo4j = Neo4jConnection()
+
+     try:
+         # Create a sample node
+         node = neo4j.create_node("Place", {
+             "id": "1",
+             "name": "Hong Kong Disneyland",
+             "type": "attraction"
+         })
+         print("Node created:", node)
+
+     finally:
+         # Close the connection
+         neo4j.close()
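
Beyond the helpers above, arbitrary Cypher can go through `run_query`. A hedged sketch against a local instance (credentials are placeholders):

    conn = Neo4jConnection(uri="bolt://localhost:7687",
                           user="neo4j", password="your_password")
    try:
        rows = conn.run_query(
            "MATCH (p:Place) WHERE p.type = $type RETURN p.name AS name",
            {"type": "attraction"},
        )
    finally:
        conn.close()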