import logging
import os
import sys
from html import escape
from pathlib import Path
from typing import Dict, List, Optional, Tuple
from urllib.parse import urlparse

import gradio as gr

# Add the project root directory to the Python path so the local `src` package is importable.
sys.path.append(str(Path(__file__).parent))

from src.api.search_api import BochaSearch  # noqa: E402
from src.core.document_processor import DocumentProcessor  # noqa: E402
from src.core.ranking import RankingSystem  # noqa: E402
from src.core.plan_generator import PlanGenerator  # noqa: E402
from src.core.embeddings import EmbeddingModel  # noqa: E402
from src.core.reranker import Reranker  # noqa: E402
from src.api.llm_api import DeepseekInterface, LLMInterface, OpenAIInterface  # noqa: E402
from src.utils.helpers import load_config  # noqa: E402

# Configure logging for the whole app.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# --- Presentation constants for the generated plan -------------------------------------
# NOTE(review): the original inline styles were lost when this file was flattened; these
# dark-theme styles are a reconstruction — adjust to taste.
_PLAN_CONTAINER_STYLE = (
    "background-color: #1a1b26; color: #e0e0f0; padding: 24px; "
    "border-radius: 12px; line-height: 1.8;"
)
_SECTION_STYLE = "margin: 20px 0 12px;"
_SECTION_TITLE_STYLE = (
    "color: #c0c0ff; font-size: 1.5em; margin: 0; "
    "padding-bottom: 6px; border-bottom: 1px solid #3a3b5c;"
)
_SUBTITLE_STYLE = "color: #a0c4ff; font-size: 1.2em; margin: 0;"
_KEY_INFO_STYLE = "margin: 6px 0;"
_KEY_LABEL_STYLE = "color: #9a9ab0;"
_KEY_VALUE_STYLE = "font-weight: bold; color: #ffffff; margin-left: 6px;"
_INFO_LINE_STYLE = "margin: 4px 0 4px 20px;"
_PARAGRAPH_STYLE = "margin: 8px 0;"

# Markdown markers stripped from the LLM output, applied in this order
# ("###" before "##" before "# " so longer markers are removed first).
_MARKDOWN_MARKERS = ("###", "##", "# ", "**")

# Prefix -> icon for "Label: value" detail lines inside the itinerary.
_INFO_ICONS = {
    "Location": "📍",
    "Activity": "🎯",
    "Transportation": "🚇",
    "Specific Guidance": "🗺️",
}
_INFO_PREFIXES = tuple(_INFO_ICONS)

# Loading indicator; class names match the `.loading` / `.loading-container` CSS rules.
_LOADING_HTML = (
    '<div class="loading-container">'
    '<div class="loading">Generating your personalized travel plan...</div>'
    '</div>'
)


class TravelRAGSystem:
    """Travel planner: web search -> passage processing/ranking -> LLM plan generation."""

    def __init__(self):
        self.config = load_config("config/config.yaml")
        # LLM instances keyed by provider name (only providers enabled in the config).
        self.llm_instances: Dict[str, LLMInterface] = {}
        # LLM instances built on demand for an explicitly selected (provider, model) pair.
        self._model_llm_cache: Dict[Tuple[str, str], LLMInterface] = {}
        # Only the "standard" retrieval method is supported.
        self.retrieval_method = "standard"
        self.init_llm_instances()
        self.init_components()

    def init_components(self):
        """Create the default LLM, search client, embedding/reranking models and pipeline stages.

        Raises:
            ValueError: if the configured default provider has no provider entry.
            Exception: re-raised from the embedding or reranker model loaders.
        """
        # Look up the configuration entry of the default provider.
        default_provider = self.config['llm_settings']['default_provider']
        provider_config = next(
            (p for p in self.config['llm_settings']['providers'] if p['name'] == default_provider),
            None,
        )
        if not provider_config:
            raise ValueError(f"未找到默认提供商 {default_provider} 的配置")

        self.llm = self.init_llm(provider_config['name'], provider_config['model'])

        self.search_engine = BochaSearch(
            api_key=self.config['bocha_api_key'],
            base_url=self.config['bocha_base_url'],
        )

        self.doc_processor = DocumentProcessor(self.llm)

        # Embedding model, loaded by Hugging Face model id.
        try:
            self.embedding_model = EmbeddingModel(model_name="BAAI/bge-m3")
            logger.info("成功加载嵌入模型")
        except Exception as e:
            logger.error(f"加载嵌入模型失败: {str(e)}")
            raise

        # Reranker, loaded by Hugging Face model id.
        try:
            self.reranker = Reranker(model_path="BAAI/bge-reranker-large")
            logger.info("成功加载重排序模型")
        except Exception as e:
            logger.error(f"加载重排序模型失败: {str(e)}")
            raise

        self.ranking_system = RankingSystem(self.embedding_model, self.reranker)
        self.plan_generator = PlanGenerator(self.llm)

    def init_llm(self, provider: str, model: str):
        """Build a new LLM client for ``provider`` using ``model``.

        Raises:
            ValueError: for providers other than "openai" and "deepseek".
            StopIteration: for "deepseek" when no provider entry named "deepseek" exists.
        """
        if provider == "openai":
            return OpenAIInterface(
                api_key=self.config['openai_api_key'],
                model=model,
            )
        elif provider == "deepseek":
            return DeepseekInterface(
                api_key=self.config['deepseek_api_key'],
                base_url=next(
                    p['base_url']
                    for p in self.config['llm_settings']['providers']
                    if p['name'] == 'deepseek'
                ),
                model=model,
            )
        else:
            raise ValueError(f"不支持的LLM提供商: {provider}")

    def init_llm_instances(self):
        """Create one LLM client per enabled provider, keyed by the provider's own name.

        Unsupported providers and construction failures are logged and skipped, so one
        misconfigured provider does not prevent the app from starting.
        """
        for provider in self.config['llm_settings']['providers']:
            if not provider.get('enabled', False):
                continue
            name = provider['name']
            try:
                if name == "openai":
                    instance = OpenAIInterface(
                        api_key=self.config['openai_api_key'],
                        model=provider['model'],
                    )
                elif name == "deepseek":
                    instance = DeepseekInterface(
                        api_key=self.config['deepseek_api_key'],
                        base_url=provider['base_url'],
                        model=provider['model'],
                    )
                else:
                    # Previously every non-openai provider was silently built as DeepSeek
                    # and stored under the "deepseek" key.
                    logger.warning(f"不支持的LLM提供商: {name}")
                    continue
                self.llm_instances[name] = instance
                logger.info(f"成功初始化 {name} LLM")
            except Exception as e:
                logger.error(f"初始化 {name} LLM 失败: {str(e)}")

    def get_llm(self, provider_name: str = None) -> LLMInterface:
        """Return the LLM instance for ``provider_name`` (default provider when omitted).

        Raises:
            ValueError: if that provider was not enabled/initialised.
        """
        if not provider_name:
            provider_name = self.config['llm_settings']['default_provider']
        if provider_name not in self.llm_instances:
            raise ValueError(f"未找到或未启用的LLM提供商: {provider_name}")
        return self.llm_instances[provider_name]

    def _resolve_llm(self, llm_provider: str, llm_model: Optional[str]) -> LLMInterface:
        """Pick the LLM for one request, honouring the selected model when it is configured."""
        if llm_provider not in self.llm_instances:
            # The message promises a fallback, so actually fall back; get_llm() raises
            # ValueError if the default provider is unavailable too.
            logger.warning(f"LLM提供商 {llm_provider} 未启用或不可用,将使用默认提供商")
            return self.get_llm()

        base_instance = self.llm_instances[llm_provider]
        provider_cfg = next(
            (p for p in self.config['llm_settings']['providers'] if p['name'] == llm_provider),
            {},
        )
        if not llm_model or llm_model == provider_cfg.get('model'):
            return base_instance
        if llm_model not in (provider_cfg.get('models') or []):
            # Guards against a model left selected from another provider.
            logger.warning(f"模型 {llm_model} 不属于提供商 {llm_provider},使用默认模型")
            return base_instance

        key = (llm_provider, llm_model)
        if key not in self._model_llm_cache:
            try:
                self._model_llm_cache[key] = self.init_llm(llm_provider, llm_model)
            except Exception as e:
                logger.warning(f"初始化模型 {llm_model} 失败,使用默认模型: {str(e)}")
                return base_instance
        return self._model_llm_cache[key]

    def process_query(
        self,
        query: str,
        days: int,
        llm_provider: str,
        llm_model: str,
        enable_images: bool = True,
        retrieval_method: str = None,
    ) -> tuple:
        """Run the full pipeline for one user request.

        Args:
            query: free-text travel requirement.
            days: trip length; appended to the search query when > 0.
            llm_provider: provider name; falls back to the default provider if unavailable.
            llm_model: model to use for that provider (ignored when not in its `models` list).
            enable_images: whether to search for and verify related images.
            retrieval_method: optional retrieval method to switch to ("standard" only).

        Returns:
            ``(plan_html, sources_markdown, images_html)``. On any error, the first item is
            an error message and the other two are empty strings.
        """
        try:
            # Switch retrieval method if a different one was requested.
            if retrieval_method and retrieval_method != self.retrieval_method:
                self.set_retrieval_method(retrieval_method)

            current_llm = self._resolve_llm(llm_provider, llm_model)
            self.doc_processor = DocumentProcessor(current_llm)
            self.plan_generator = PlanGenerator(current_llm)

            # Make sure the query mentions the number of days (slider values may be floats).
            days = int(days or 0)
            if days > 0:
                query = f"{query} {days} days"

            logger.info(f"执行搜索: {query}")
            search_results = self.search_engine.search(query)
            logger.info(f"搜索结果: {search_results}")

            passages = self.doc_processor.process_documents(search_results)
            logger.info(f"处理后的文档: {passages}")

            final_ranked = self._rank_passages(query, passages)

            final_plan = self.plan_generator.generate_plan(query, final_ranked)
            logger.info(f"生成的计划: {final_plan}")

            sources = self._build_sources_table(final_ranked)
            logger.info(f"参考来源: {sources}")

            image_html = self._collect_images_html(query) if enable_images else ""

            return self._format_plan_html(final_plan['plan']), sources, image_html
        except Exception as e:
            logger.error(f"Error processing query: {str(e)}")
            # The message is rendered into an HTML component, so escape it.
            return (
                f"Sorry, an error occurred while processing your request: {escape(str(e))}",
                "",
                "",
            )

    def _rank_passages(self, query: str, passages):
        """Rank passages with the active retriever if it has `retrieve`, else two-stage ranking."""
        # set_retrieval_method() assigns the RankingSystem itself as retriever, and only
        # initial_ranking/rerank are known to exist on it, so check for `retrieve` first.
        retrieve = getattr(getattr(self, 'retriever', None), 'retrieve', None)
        if callable(retrieve):
            return retrieve(query, passages)
        initial_ranked = self.ranking_system.initial_ranking(query, passages)
        return self.ranking_system.rerank(query, initial_ranked)

    @staticmethod
    def _build_sources_table(ranked_docs) -> str:
        """Render ranked documents as a Markdown table of links and scores."""
        header = (
            "| Reference URL | Relevance Score | Retrieval Score | Rerank Score |\n"
            "| --- | --- | --- | --- |"
        )
        rows = []
        for doc in ranked_docs:
            url = doc['url']
            # Use the domain as the link text when the title is missing or blank.
            title = (doc.get('title') or '').strip() or urlparse(url).netloc
            # Escape characters that would break the table cell or the link syntax.
            title = title.replace('|', '\\|').replace('[', '\\[').replace(']', '\\]')
            rows.append(
                f"| [{title}]({url}) | "
                f"{doc.get('final_score') or 0:.3f} | "
                f"{doc.get('retrieval_score') or 0:.3f} | "
                f"{doc.get('rerank_score') or 0:.3f} |"
            )
        return header + "\n" + "\n".join(rows)

    def _collect_images_html(self, query: str, max_images: int = 3, search_count: int = 8) -> str:
        """Search images for ``query`` and return gallery HTML for up to ``max_images`` that pass
        ``verify_image_url``; returns "" when none qualify or the search fails."""
        try:
            # Over-fetch because many results are filtered out by verification.
            images = self.search_engine.search_images(query, count=search_count)
            valid_images = []
            for img in images or []:
                img_url = img.get('url', '')
                if img_url and self.verify_image_url(img_url):
                    valid_images.append(img_url)
                    if len(valid_images) >= max_images:
                        break
            return self._format_images_html(valid_images)
        except Exception as e:
            logger.warning(f"获取图片时出现错误: {str(e)}")
            return ""

    def _format_plan_html(self, plan_text: str) -> str:
        """Strip Markdown markers from the plan and render each line as styled HTML."""
        for marker in _MARKDOWN_MARKERS:
            plan_text = plan_text.replace(marker, '')
        fragments = [
            self._format_plan_line(line.strip())
            for line in plan_text.split('\n')
            if line.strip()
        ]
        plan_content = '\n'.join(fragments)
        # Wrap everything in a single dark-themed container.
        return f'<div style="{_PLAN_CONTAINER_STYLE}">\n{plan_content}\n</div>'

    @staticmethod
    def _format_plan_line(line: str) -> str:
        """Render one non-empty, stripped plan line; LLM text is HTML-escaped."""
        safe_line = escape(line)
        if "Tour Overview" in line:
            # Main section title.
            return (
                f'<div style="{_SECTION_STYLE}">'
                f'<h2 style="{_SECTION_TITLE_STYLE}">📍 {safe_line}</h2>'
                '</div>'
            )
        if ":" in line and any(x in line for x in ("Date", "Destination", "Key Attractions")):
            # Key facts: label followed by a bold value.
            key, value = line.split(":", 1)
            return (
                f'<div style="{_KEY_INFO_STYLE}">'
                f'<span style="{_KEY_LABEL_STYLE}">{escape(key)}:</span>'
                f'<span style="{_KEY_VALUE_STYLE}">{escape(value.strip())}</span>'
                '</div>'
            )
        if "Daily Itinerary" in line:
            # Main section title.
            return (
                f'<div style="{_SECTION_STYLE}">'
                f'<h2 style="{_SECTION_TITLE_STYLE}">🕒️ {safe_line}</h2>'
                '</div>'
            )
        if " - " in line:
            # Time-slot sub-heading.
            return (
                f'<div style="{_SECTION_STYLE}">'
                f'<h3 style="{_SUBTITLE_STYLE}">🕒 {safe_line}</h3>'
                '</div>'
            )
        if line.startswith(_INFO_PREFIXES) and ":" in line:
            # Detail line: icon + value. Lines without ":" used to raise ValueError here.
            key, value = line.split(":", 1)
            icon = _INFO_ICONS.get(key.strip(), "•")
            return (
                f'<div style="{_INFO_LINE_STYLE}">'
                f'<span style="margin-right: 8px;">{icon}</span>'
                f'<span>{escape(value.strip())}</span>'
                '</div>'
            )
        # Ordinary paragraph.
        return f'<div style="{_PARAGRAPH_STYLE}"><p>{safe_line}</p></div>'

    def verify_image_url(self, url: str) -> bool:
        """Download an image and apply heuristics to decide whether it is a usable photo.

        Rejects: non-200 or non-image responses, images smaller than 300x300, aspect ratio
        outside [0.5, 2.0], multi-channel images with pixel std < 30 (too flat), mean edge
        intensity > 30 (likely text-heavy), mean HSV saturation > 200 (likely ads).
        Any download/decoding error is logged and yields False.
        """
        try:
            import io

            import numpy as np
            import requests
            from PIL import Image, ImageFilter

            response = requests.get(url, timeout=3)
            if response.status_code != 200:
                return False

            content_type = response.headers.get('content-type', '')
            if 'image' not in content_type.lower():
                return False

            with Image.open(io.BytesIO(response.content)) as img:
                # 1. Size filter.
                width, height = img.size
                if width < 300 or height < 300:
                    return False

                # 2. Aspect-ratio filter.
                aspect_ratio = width / height
                if aspect_ratio < 0.5 or aspect_ratio > 2.0:
                    return False

                # 3/4. Reject flat, low-variance multi-channel images (e.g. plain text graphics).
                img_array = np.array(img)
                is_multichannel = len(img_array.shape) == 3
                if is_multichannel and np.std(img_array) < 30:
                    return False

                # 5. Crude text detection via edge density on the grayscale image.
                img_gray = img if img.mode == 'L' else img.convert('L')
                edges = img_gray.filter(ImageFilter.FIND_EDGES)
                if np.mean(np.array(edges)) > 30:
                    return False

                # 6. Reject over-saturated images (likely adverts). Convert to RGB first so
                # non-RGB modes (e.g. RGBA) do not fail the HSV conversion and get rejected.
                if is_multichannel:
                    hsv = img.convert('RGB').convert('HSV')
                    saturation = np.array(hsv)[:, :, 1]
                    if np.mean(saturation) > 200:
                        return False

            return True
        except Exception as e:
            logger.warning(f"图片验证失败: {str(e)}")
            return False

    def _format_images_html(self, images: List[str]) -> str:
        """Render image URLs as a flex gallery; images that fail to load hide themselves."""
        if not images:
            return ""

        parts = [
            '<div class="images-container" '
            'style="display: flex; flex-wrap: wrap; gap: 10px; justify-content: center;">'
        ]
        for img_url in images:
            safe_url = escape(img_url, quote=True)
            parts.append(
                '<div style="flex: 1 1 200px; max-width: 33%;">'
                f'<img src="{safe_url}" alt="旅行相关图片" loading="lazy" '
                'style="width: 100%; height: 200px; object-fit: cover; border-radius: 8px;" '
                "onerror=\"this.parentElement.style.display='none'\">"
                '</div>'
            )
        parts.append('</div>')
        html = "\n".join(parts)

        logger.info(f"生成的图片HTML: {html[:200]}...")  # only the first 200 chars
        return html

    def set_retrieval_method(self, method: str):
        """Switch the retrieval method; only "standard" is supported.

        Raises:
            ValueError: for any other method name.
        """
        if method not in ["standard"]:
            raise ValueError(f"不支持的检索方法: {method}")

        self.retrieval_method = method
        if method == "standard":
            self.retriever = self.ranking_system


def create_interface():
    """Build the Gradio Blocks UI around one shared TravelRAGSystem instance."""
    system = TravelRAGSystem()
    llm_settings = system.config['llm_settings']
    default_provider = llm_settings['default_provider']

    # Treat a provider without an explicit `enabled` flag as disabled, matching
    # TravelRAGSystem.init_llm_instances (which uses the same .get default).
    enabled_configs = [p for p in llm_settings['providers'] if p.get('enabled', False)]
    enabled_providers = [p['name'] for p in enabled_configs]

    # Provider name -> selectable models; fall back to the single configured model.
    provider_models = {
        p['name']: p.get('models') or ([p['model']] if p.get('model') else [])
        for p in enabled_configs
    }

    css = """
    .gradio-container { font-family: "PingFang SC", "Microsoft YaHei", sans-serif; }
    /* 针对所有英文文本 */
    [class*="message-"] { font-family: 'Times New Roman', serif !important; }
    /* 确保英文和数字使用 Times New Roman */
    .gradio-container *:not(:lang(zh)) { font-family: 'Times New Roman', serif !important; }
    @keyframes spin { 0% { transform: rotate(0deg); } 100% { transform: rotate(360deg); } }
    /* 隐藏数字输入框的上下箭头 */
    input[type="number"]::-webkit-inner-spin-button,
    input[type="number"]::-webkit-outer-spin-button { -webkit-appearance: none; margin: 0; }
    input[type="number"] { -moz-appearance: textfield; }
    /* 隐藏默认的 processing 信息和箭头 */
    .progress-text, .meta-text-center, .progress-container { display: none !important; }
    /* 修改加载动画样式 */
    .loading {
        display: flex; align-items: center; justify-content: center;
        gap: 8px; font-size: 1.2em; color: rgb(192, 192, 255);
    }
    .loading::before {
        content: '🌍'; display: inline-block; animation: spin 2s linear infinite;
        filter: brightness(1.5); /* 让地球图标更亮 */
    }
    /* 调整 Gradio 默认加载动画的位置 */
    .progress-text { display: block !important; order: 3; margin-top: 8px; opacity: 0.7; }
    .meta-text-center { display: block !important; }
    /* 确保加载容器使用 flex 布局 */
    .loading-container { display: flex; flex-direction: column; align-items: center; }
    /* 隐藏滑块右侧的上下箭头 */
    .num-input-plus, .num-input-minus { display: none !important; }
    /* 隐藏所有滚动箭头 */
    .scroll-hide, .output-markdown, .output-text, .markdown-text, .prose, .gr-box, .gr-panel {
        -ms-overflow-style: none !important; scrollbar-width: none !important;
        overflow-y: hidden !important; overflow: hidden !important;
    }
    .scroll-hide::-webkit-scrollbar, .output-markdown::-webkit-scrollbar,
    .output-text::-webkit-scrollbar, .markdown-text::-webkit-scrollbar,
    .prose::-webkit-scrollbar, .gr-box::-webkit-scrollbar, .gr-panel::-webkit-scrollbar {
        display: none !important; width: 0 !important; height: 0 !important;
    }
    /* 修改加载动画容器样式 */
    .loading-container { overflow: hidden !important; min-height: 60px; }
    /* 隐藏 Gradio 默认的滚动控件 */
    .wrap.svelte-byatnx, .contain.svelte-byatnx, [class*='svelte'], .gradio-container {
        overflow: hidden !important; overflow-y: hidden !important;
    }
    /* 禁用所有可能的滚动控件 */
    ::-webkit-scrollbar { display: none !important; width: 0 !important; height: 0 !important; }
    /* 移除 Group 组件的默认背景 */
    .custom-group { border: none !important; background: none !important; box-shadow: none !important; }
    .custom-group > div { border: none !important; background: none !important; box-shadow: none !important; }
    /* 添加图片容器样式 */
    .images-container {
        margin-top: 20px; padding: 10px; background: #fff;
        border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.1);
    }
    .images-container img { transition: transform 0.3s ease; }
    .images-container img:hover { transform: scale(1.05); }
    /* 确保图片容器可见 */
    #component-13 { min-height: 200px; overflow: visible !important; }
    """

    with gr.Blocks(theme=gr.themes.Soft(), css=css) as interface:
        gr.Markdown(
            """
            # 🌟 Tourism Planning Assistant 🌟

            Welcome to the Smart Travel Planning Assistant! Simply input your travel requirements, and we'll generate a personalized travel plan for you.

            ### Instructions
            1. Describe your travel needs in the input box (e.g., 'One-day trip to Hong Kong Disneyland')
            2. Select the number of days for your plan
            3. Click the "Generate Plan" button
            """
        )

        with gr.Row():
            with gr.Column(scale=4):
                llm_provider = gr.Dropdown(
                    choices=enabled_providers,
                    value=default_provider,
                    label="Select LLM Provider",
                )
                llm_model = gr.Dropdown(
                    choices=provider_models.get(default_provider, []),
                    label="Select Model",
                )

                def update_model_choices(provider):
                    """Refresh model choices for the newly selected provider."""
                    choices = provider_models.get(provider, [])
                    # Reset the selection so a model of the previous provider is never submitted.
                    return gr.Dropdown(choices=choices, value=choices[0] if choices else None)

                llm_provider.change(
                    fn=update_model_choices,
                    inputs=[llm_provider],
                    outputs=[llm_model],
                )

                query_input = gr.Textbox(
                    label="Travel Requirements",
                    placeholder="Please enter your travel requirements, e.g.: One-day trip to Hong Kong Disneyland",
                    lines=2,
                )
                days_input = gr.Slider(
                    minimum=1,
                    maximum=7,
                    value=1,
                    step=1,
                    label="Number of Days",
                )

                show_images = gr.Checkbox(
                    label="Search Related Images",
                    value=True,
                    info="Whether to search and display related reference images",
                )

                # Only "standard" retrieval remains, so the selector is hidden.
                retrieval_method = gr.Radio(
                    choices=["standard"],
                    value="standard",
                    label="Retrieval Method",
                    info="Choose different retrieval strategies",
                    visible=False,
                )

                submit_btn = gr.Button("Generate Plan", variant="primary")
                loading_status = gr.Markdown("", elem_id="loading_status", show_label=False)

                # Image gallery area in the left column (initially empty).
                images_container = gr.HTML(
                    value="",
                    visible=True,
                    label="Related Images",
                )

                # Clear the gallery when images are disabled.
                show_images.change(
                    fn=lambda x: "" if not x else '<div class="images-container"></div>',
                    inputs=[show_images],
                    outputs=[images_container],
                )

            with gr.Column(scale=6):
                with gr.Tabs():
                    with gr.TabItem("Travel Plan"):
                        plan_output = gr.HTML(label="Generated Travel Plan", show_label=False)
                    with gr.TabItem("References and Evaluation"):
                        sources_output = gr.Markdown(label="References and Evaluation", show_label=False)

        gr.Examples(
            examples=[
                ["One-day trip to Hong Kong Disneyland", 1],
                ["Family trip to Hong Kong Ocean Park", 1],
                ["Hong Kong Shopping and Food Tour", 2],
                ["Hong Kong Cultural Experience Tour", 3],
            ],
            inputs=[query_input, days_input],
            label="Example Queries",
        )

        def show_loading():
            """Show the loading indicator and clear previous outputs."""
            return _LOADING_HTML, _LOADING_HTML, "", ""

        def process_with_images(query, days, llm_provider, llm_model, enable_images, retrieval_method):
            """Run the pipeline and return (plan_html, sources_md, images_html)."""
            plan_html, sources_md, images_html = system.process_query(
                query, days, llm_provider, llm_model, enable_images, retrieval_method
            )
            logger.info(f"图片HTML长度: {len(images_html) if images_html else 0}")
            return plan_html, sources_md, images_html

        submit_btn.click(
            fn=show_loading,
            inputs=None,
            outputs=[loading_status, plan_output, sources_output, images_container],
        ).then(
            fn=process_with_images,
            inputs=[
                query_input,
                days_input,
                llm_provider,
                llm_model,
                show_images,
                retrieval_method,
            ],
            # Order must match the tuple returned by process_with_images.
            outputs=[plan_output, sources_output, images_container],
        ).then(
            fn=lambda: "",
            inputs=None,
            outputs=[loading_status],
        )

        gr.Markdown(
            """
            ### 📝 Notes
            - Plan generation may take some time, please be patient
            - Queries should include specific locations and activity preferences
            - All plans are AI-generated, please adjust according to actual circumstances

            Powered by RAG for Tourism system © 2024
            """
        )

    return interface


if __name__ == "__main__":
    demo = create_interface()
    # Settings suited to Hugging Face Spaces.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,  # Spaces already provides public access
        debug=False,
        ssr_mode=False,
    )