import gradio as gr import sys import os from pathlib import Path from typing import List # 添加项目根目录到 Python 路径 sys.path.append(str(Path(__file__).parent)) from src.api.search_api import BochaSearch from src.core.document_processor import DocumentProcessor from src.core.ranking import RankingSystem from src.core.plan_generator import PlanGenerator from src.core.embeddings import EmbeddingModel from src.core.reranker import Reranker from src.api.llm_api import DeepseekInterface, LLMInterface, OpenAIInterface from src.utils.helpers import load_config import logging # 设置日志 logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) class TravelRAGSystem: def __init__(self): self.config = load_config("config/config.yaml") self.llm_instances = {} # 存储不同provider的LLM实例 # 固定使用 standard 检索方法 self.retrieval_method = "standard" self.init_llm_instances() self.init_components() def init_components(self): # 获取默认提供商的配置 default_provider = self.config['llm_settings']['default_provider'] provider_config = next( (p for p in self.config['llm_settings']['providers'] if p['name'] == default_provider), None ) if not provider_config: raise ValueError(f"未找到默认提供商 {default_provider} 的配置") # 初始化LLM实例 self.llm = self.init_llm(provider_config['name'], provider_config['model']) self.search_engine = BochaSearch( api_key=self.config['bocha_api_key'], base_url=self.config['bocha_base_url'] ) self.doc_processor = DocumentProcessor(self.llm) # 初始化嵌入模型 - 使用 Hugging Face 模型 ID try: self.embedding_model = EmbeddingModel( model_name="BAAI/bge-m3" ) logger.info("成功加载嵌入模型") except Exception as e: logger.error(f"加载嵌入模型失败: {str(e)}") raise # 初始化重排序器 - 使用 Hugging Face 模型 ID try: self.reranker = Reranker( model_path="BAAI/bge-reranker-large" ) logger.info("成功加载重排序模型") except Exception as e: logger.error(f"加载重排序模型失败: {str(e)}") raise self.ranking_system = RankingSystem(self.embedding_model, self.reranker) self.plan_generator = PlanGenerator(self.llm) def init_llm(self, provider: str, model: str): if provider == "openai": return OpenAIInterface( api_key=self.config['openai_api_key'], model=model ) elif provider == "deepseek": return DeepseekInterface( api_key=self.config['deepseek_api_key'], base_url=next( p['base_url'] for p in self.config['llm_settings']['providers'] if p['name'] == 'deepseek' ), model=model ) else: raise ValueError(f"不支持的LLM提供商: {provider}") def init_llm_instances(self): """初始化所有启用的LLM实例""" for provider in self.config['llm_settings']['providers']: if provider.get('enabled', False): try: if provider['name'] == "openai": self.llm_instances['openai'] = OpenAIInterface( api_key=self.config['openai_api_key'], model=provider['model'] ) else: self.llm_instances['deepseek'] = DeepseekInterface( api_key=self.config['deepseek_api_key'], base_url=provider['base_url'], model=provider['model'] ) logging.info(f"成功初始化 {provider['name']} LLM") except Exception as e: logging.error(f"初始化 {provider['name']} LLM 失败: {str(e)}") def get_llm(self, provider_name: str = None) -> LLMInterface: """获取指定的LLM实例""" if not provider_name: provider_name = self.config['llm_settings']['default_provider'] if provider_name not in self.llm_instances: raise ValueError(f"未找到或未启用的LLM提供商: {provider_name}") return self.llm_instances[provider_name] def process_query( self, query: str, days: int, llm_provider: str, llm_model: str, enable_images: bool = True, retrieval_method: str = None ) -> tuple: try: # 如果指定了新的检索方法,则切换 if retrieval_method and retrieval_method != self.retrieval_method: self.set_retrieval_method(retrieval_method) # 确保LLM提供商存在 if llm_provider not in self.llm_instances: raise ValueError(f"LLM提供商 {llm_provider} 未启用或不可用,将使用默认提供商") current_llm = self.llm_instances[llm_provider] self.doc_processor = DocumentProcessor(current_llm) self.plan_generator = PlanGenerator(current_llm) # 确保查询包含天数 if days > 0: query = f"{query} {days} days" # 执行搜索 logger.info(f"执行搜索: {query}") search_results = self.search_engine.search(query) logger.info(f"搜索结果: {search_results}") # 处理文档 passages = self.doc_processor.process_documents(search_results) logger.info(f"处理后的文档: {passages}") # 使用当前检索器进行检索和排序 if hasattr(self, 'retriever'): final_ranked = self.retriever.retrieve(query, passages) else: # 使用默认的排序系统 initial_ranked = self.ranking_system.initial_ranking(query, passages) final_ranked = self.ranking_system.rerank(query, initial_ranked) # 生成计划 final_plan = self.plan_generator.generate_plan(query, final_ranked) logger.info(f"生成的计划: {final_plan}") # 修改准备参考来源的部分 # 创建表格的表头 table_header = "| Reference URL | Relevance Score | Retrieval Score | Rerank Score |\n| --- | --- | --- | --- |" # 准备表格行 table_rows = [] for doc in final_ranked: # 如果标题为空,使用URL作为标题 title = doc.get('title', '').strip() if not title: from urllib.parse import urlparse domain = urlparse(doc['url']).netloc title = domain # 创建表格行 row = ( f"| [{title}]({doc['url']}) | " f"{doc.get('final_score', 0):.3f} | " f"{doc.get('retrieval_score', 0):.3f} | " f"{doc.get('rerank_score', 0):.3f} |" ) table_rows.append(row) # 组合表格 sources = table_header + "\n" + "\n".join(table_rows) logger.info(f"参考来源: {sources}") # 修改图片展示部分 image_html = "" if enable_images: try: # 增加搜索数量,因为要过滤 images = self.search_engine.search_images(query, count=8) valid_images = [] if images: # 过滤图片 for img in images: img_url = img.get('url', '') if img_url and self.verify_image_url(img_url): valid_images.append(img_url) if len(valid_images) >= 3: # 只需要3张有效图片 break if valid_images: # 如果有有效图片 image_html = """
{p}
' ) plan_content = '\n'.join(formatted_paragraphs) # 将所有内容包装在一个暗色主题的容器中 final_output = f"""Generating your personalized travel plan...
'; return ['', '']; } """ with gr.Blocks(theme=gr.themes.Soft(), css=css) as interface: gr.Markdown(""" # 🌟 Tourism Planning Assistant 🌟 Welcome to the Smart Travel Planning Assistant! Simply input your travel requirements, and we'll generate a personalized travel plan for you. ### Instructions 1. Describe your travel needs in the input box (e.g., 'One-day trip to Hong Kong Disneyland') 2. Select the number of days for your plan 3. Click the "Generate Plan" button """) with gr.Row(): with gr.Column(scale=4): llm_provider = gr.Dropdown( choices=enabled_providers, value=system.config['llm_settings']['default_provider'], label="Select LLM Provider" ) llm_model = gr.Dropdown( choices=provider_models[system.config['llm_settings']['default_provider']], label="Select Model" ) # 添加更新模型选择的函数 def update_model_choices(provider): return gr.Dropdown(choices=provider_models[provider]) # 设置提供商改变时的回调 llm_provider.change( fn=update_model_choices, inputs=[llm_provider], outputs=[llm_model] ) query_input = gr.Textbox( label="Travel Requirements", placeholder="Please enter your travel requirements, e.g.: One-day trip to Hong Kong Disneyland", lines=2 ) days_input = gr.Slider( minimum=1, maximum=7, value=1, step=1, label="Number of Days" ) # 添加显示图片的复选框 show_images = gr.Checkbox( label="Search Related Images", value=True, info="Whether to search and display related reference images" ) # 移除 memorag 和 graphrag 选项,只保留 standard retrieval_method = gr.Radio( choices=["standard"], value="standard", label="Retrieval Method", info="Choose different retrieval strategies", visible=False # 由于只有一个选项,可以直接隐藏 ) submit_btn = gr.Button("Generate Plan", variant="primary") loading_status = gr.Markdown("", elem_id="loading_status", show_label=False) # 添加图片展示区域到左侧列 images_container = gr.HTML( value="", # 确保初始值为空字符串 visible=True, label="Related Images" ) # 当复选框状态改变时更新图片区域的显示状态 show_images.change( fn=lambda x: "" if not x else "", # 当禁用图片时返回空字符串 inputs=[show_images], outputs=[images_container] ) with gr.Column(scale=6): with gr.Tabs(): with gr.TabItem("Travel Plan"): plan_output = gr.HTML(label="Generated Travel Plan", show_label=False) with gr.TabItem("References and Evaluation"): sources_output = gr.Markdown(label="References and Evaluation", show_label=False) # 修改示例为英文 gr.Examples( examples=[ ["One-day trip to Hong Kong Disneyland", 1], ["Family trip to Hong Kong Ocean Park", 1], ["Hong Kong Shopping and Food Tour", 2], ["Hong Kong Cultural Experience Tour", 3] ], inputs=[query_input, days_input], label="Example Queries" ) def show_loading(): loading_html = "Generating your personalized travel plan...