""" 优化调度器 - 负责管理LLM优化任务队列 """ import asyncio import logging from concurrent.futures import ThreadPoolExecutor from typing import Callable, Optional from optimizer.llm_api_runner import ChatGPTRunner from optimizer.optimize_task import OptimizeTask # 配置日志 def setup_logger(name, level=logging.INFO): """设置日志记录器""" logger = logging.getLogger(name) # 清除所有已有的handler,避免重复 if logger.handlers: logger.handlers.clear() # 添加新的handler handler = logging.StreamHandler() formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') handler.setFormatter(formatter) logger.addHandler(handler) logger.setLevel(level) # 禁止传播到父logger,避免重复日志 logger.propagate = False return logger # 创建日志记录器 logger = setup_logger("optimizer") class OptimizationDispatcher: """ 优化调度器,负责管理LLM优化任务队列 支持异步处理多个优化任务 """ def __init__(self, max_workers: int = 2, callback: Optional[Callable] = None): """ 初始化优化调度器 :param max_workers: 最大工作线程数 :param callback: 优化完成后的回调函数 """ self.tasks = {} # 存储任务ID到任务的映射 self.executor = ThreadPoolExecutor(max_workers=max_workers) self.model_runner = ChatGPTRunner() self.callback = callback logger.debug(f"优化调度器初始化完成,最大工作线程数: {max_workers}") def submit(self, sentence_id: str, text: str, callback: Optional[Callable] = None): """ 提交优化任务 :param sentence_id: 句子ID :param text: 需要优化的文本 :param callback: 优化完成后的回调函数,如果为None则使用默认回调 """ task_callback = callback or self.callback task = OptimizeTask(sentence_id, text, task_callback) self.tasks[sentence_id] = task logger.debug(f"提交优化任务: {sentence_id}") # 在线程池中执行任务 self.executor.submit(self._process_task, task) logger.debug(f"任务已提交到线程池: {sentence_id}") def _process_task(self, task: OptimizeTask): """ 处理优化任务 :param task: 优化任务 """ try: logger.debug(f"开始处理任务: {task.sentence_id}") # 使用模型运行器优化文本 optimized_text = self.model_runner.optimize(task.text) logger.debug(f"任务处理完成: {task.sentence_id}") # 调用回调函数 if task.callback: task.callback(task.sentence_id, task.text, optimized_text) logger.debug(f"已调用回调函数: {task.sentence_id}") # 从任务列表中移除 if task.sentence_id in self.tasks: del self.tasks[task.sentence_id] logger.info(f"优化任务完成: {task.sentence_id}") except Exception as e: logger.error(f"处理任务出错: {task.sentence_id}, 错误: {str(e)}") def wait_until_done(self, timeout: Optional[float] = None): """ 等待所有任务完成 :param timeout: 超时时间(秒),如果为None则一直等待 :return: 是否所有任务都已完成 """ logger.debug(f"等待所有任务完成,当前任务数: {len(self.tasks)}") self.executor.shutdown(wait=True, timeout=timeout) # 创建新的线程池 self.executor = ThreadPoolExecutor(max_workers=self.executor._max_workers) logger.debug("所有任务已完成") return True if __name__ == "__main__": # 设置日志级别为DEBUG以查看详细信息 logger.setLevel(logging.DEBUG) # 测试回调函数 def test_callback(sentence_id, original_text, optimized_text): logger.info(f"[回填] {sentence_id}: {optimized_text}") # 创建调度器 dispatcher = OptimizationDispatcher(callback=test_callback) # 提交测试任务 dispatcher.submit("s001", "we maybe start tomorrow okay") dispatcher.submit("s002", "they need eat fast meeting now") # 等待任务完成 dispatcher.wait_until_done() logger.info("测试完成")