tianyaogavin's picture
init main framework
1bf36cc
"""
优化调度器 - 负责管理LLM优化任务队列
"""
import asyncio
import logging
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, Optional
from optimizer.llm_api_runner import ChatGPTRunner
from optimizer.optimize_task import OptimizeTask
# 配置日志
def setup_logger(name, level=logging.INFO):
"""设置日志记录器"""
logger = logging.getLogger(name)
# 清除所有已有的handler,避免重复
if logger.handlers:
logger.handlers.clear()
# 添加新的handler
handler = logging.StreamHandler()
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
handler.setFormatter(formatter)
logger.addHandler(handler)
logger.setLevel(level)
# 禁止传播到父logger,避免重复日志
logger.propagate = False
return logger
# 创建日志记录器
logger = setup_logger("optimizer")
class OptimizationDispatcher:
"""
优化调度器,负责管理LLM优化任务队列
支持异步处理多个优化任务
"""
def __init__(self, max_workers: int = 2, callback: Optional[Callable] = None):
"""
初始化优化调度器
:param max_workers: 最大工作线程数
:param callback: 优化完成后的回调函数
"""
self.tasks = {} # 存储任务ID到任务的映射
self.executor = ThreadPoolExecutor(max_workers=max_workers)
self.model_runner = ChatGPTRunner()
self.callback = callback
logger.debug(f"优化调度器初始化完成,最大工作线程数: {max_workers}")
def submit(self, sentence_id: str, text: str, callback: Optional[Callable] = None):
"""
提交优化任务
:param sentence_id: 句子ID
:param text: 需要优化的文本
:param callback: 优化完成后的回调函数,如果为None则使用默认回调
"""
task_callback = callback or self.callback
task = OptimizeTask(sentence_id, text, task_callback)
self.tasks[sentence_id] = task
logger.debug(f"提交优化任务: {sentence_id}")
# 在线程池中执行任务
self.executor.submit(self._process_task, task)
logger.debug(f"任务已提交到线程池: {sentence_id}")
def _process_task(self, task: OptimizeTask):
"""
处理优化任务
:param task: 优化任务
"""
try:
logger.debug(f"开始处理任务: {task.sentence_id}")
# 使用模型运行器优化文本
optimized_text = self.model_runner.optimize(task.text)
logger.debug(f"任务处理完成: {task.sentence_id}")
# 调用回调函数
if task.callback:
task.callback(task.sentence_id, task.text, optimized_text)
logger.debug(f"已调用回调函数: {task.sentence_id}")
# 从任务列表中移除
if task.sentence_id in self.tasks:
del self.tasks[task.sentence_id]
logger.info(f"优化任务完成: {task.sentence_id}")
except Exception as e:
logger.error(f"处理任务出错: {task.sentence_id}, 错误: {str(e)}")
def wait_until_done(self, timeout: Optional[float] = None):
"""
等待所有任务完成
:param timeout: 超时时间(秒),如果为None则一直等待
:return: 是否所有任务都已完成
"""
logger.debug(f"等待所有任务完成,当前任务数: {len(self.tasks)}")
self.executor.shutdown(wait=True, timeout=timeout)
# 创建新的线程池
self.executor = ThreadPoolExecutor(max_workers=self.executor._max_workers)
logger.debug("所有任务已完成")
return True
if __name__ == "__main__":
# 设置日志级别为DEBUG以查看详细信息
logger.setLevel(logging.DEBUG)
# 测试回调函数
def test_callback(sentence_id, original_text, optimized_text):
logger.info(f"[回填] {sentence_id}: {optimized_text}")
# 创建调度器
dispatcher = OptimizationDispatcher(callback=test_callback)
# 提交测试任务
dispatcher.submit("s001", "we maybe start tomorrow okay")
dispatcher.submit("s002", "they need eat fast meeting now")
# 等待任务完成
dispatcher.wait_until_done()
logger.info("测试完成")