|
""" |
|
优化调度器 - 负责管理LLM优化任务队列 |
|
""" |
|
|
|
import asyncio
import concurrent.futures
import logging
import threading
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, Optional

from optimizer.llm_api_runner import ChatGPTRunner
from optimizer.optimize_task import OptimizeTask
|
|
|
|
|
def setup_logger(name, level=logging.INFO):
    """Return the logger called *name*, configured with one stream handler.

    Any handlers already attached to that logger are discarded first, so
    calling this function repeatedly never stacks duplicate handlers.

    :param name: name passed to ``logging.getLogger``
    :param level: level applied to the logger (defaults to ``logging.INFO``)
    :return: the configured ``logging.Logger``
    """
    log = logging.getLogger(name)

    # Start from a clean slate; clearing an empty handler list is a no-op.
    log.handlers.clear()

    # Records are emitted by this logger's own handler only, not by ancestors.
    log.propagate = False
    log.setLevel(level)

    stream = logging.StreamHandler()
    stream.setFormatter(
        logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    )
    log.addHandler(stream)
    return log


# Module-wide logger shared by the dispatcher and the demo script below.
logger = setup_logger("optimizer")
|
|
|
class OptimizationDispatcher:
    """Dispatch LLM text-optimization tasks to a thread pool.

    Every submitted sentence is wrapped in an ``OptimizeTask`` and processed
    on a ``ThreadPoolExecutor`` worker, so several optimizations can be in
    flight at once. Tasks that have been submitted but not yet finished are
    kept in ``self.tasks``, keyed by sentence id.
    """

    def __init__(self, max_workers: int = 2, callback: Optional[Callable] = None):
        """Create the dispatcher and its worker pool.

        :param max_workers: maximum number of worker threads
        :param callback: default completion callback, invoked as
            ``callback(sentence_id, original_text, optimized_text)``
        """
        self.tasks = {}
        # Futures of tasks that have not finished yet. The set is touched by
        # the submitting thread and by worker threads (done-callbacks), hence
        # the lock.
        self._futures = set()
        self._lock = threading.Lock()
        self.executor = ThreadPoolExecutor(max_workers=max_workers)
        self.model_runner = ChatGPTRunner()
        self.callback = callback
        logger.debug(f"优化调度器初始化完成,最大工作线程数: {max_workers}")

    def submit(self, sentence_id: str, text: str, callback: Optional[Callable] = None) -> None:
        """Queue one sentence for optimization.

        :param sentence_id: identifier of the sentence
        :param text: text to optimize
        :param callback: completion callback for this task; when falsy the
            dispatcher's default callback is used instead
        """
        task_callback = callback or self.callback
        task = OptimizeTask(sentence_id, text, task_callback)
        self.tasks[sentence_id] = task
        logger.debug(f"提交优化任务: {sentence_id}")

        future = self.executor.submit(self._process_task, task)
        # Register the future before attaching the done-callback: if the task
        # has already finished, the callback runs immediately and must find
        # the future in the set to remove it.
        with self._lock:
            self._futures.add(future)
        future.add_done_callback(self._forget_future)
        logger.debug(f"任务已提交到线程池: {sentence_id}")

    def _forget_future(self, future) -> None:
        """Stop tracking *future* once it has completed."""
        with self._lock:
            self._futures.discard(future)

    def _process_task(self, task: OptimizeTask) -> None:
        """Run one optimization task on a worker thread.

        Calls the model runner, hands the result to the task's callback (if
        any) and removes the task from ``self.tasks``. Exceptions are logged
        rather than propagated.

        :param task: the task to process
        """
        try:
            logger.debug(f"开始处理任务: {task.sentence_id}")

            optimized_text = self.model_runner.optimize(task.text)
            logger.debug(f"任务处理完成: {task.sentence_id}")

            if task.callback:
                task.callback(task.sentence_id, task.text, optimized_text)
                logger.debug(f"已调用回调函数: {task.sentence_id}")

            logger.info(f"优化任务完成: {task.sentence_id}")
        except Exception as e:
            # Worker boundary: log with the traceback instead of letting the
            # exception disappear inside the future.
            logger.error(f"处理任务出错: {task.sentence_id}, 错误: {e}", exc_info=True)
        finally:
            # Always drop the bookkeeping entry, including on failure, so that
            # failed tasks do not pile up in self.tasks.
            self.tasks.pop(task.sentence_id, None)

    def wait_until_done(self, timeout: Optional[float] = None) -> bool:
        """Block until the tasks submitted so far have finished.

        The worker pool is left running, so the dispatcher can keep accepting
        tasks afterwards. Tasks submitted while this call is already waiting
        are not included in the wait.

        :param timeout: maximum number of seconds to wait; ``None`` waits
            indefinitely
        :return: ``True`` if every awaited task finished, ``False`` if the
            timeout expired first
        """
        logger.debug(f"等待所有任务完成,当前任务数: {len(self.tasks)}")
        with self._lock:
            pending = list(self._futures)

        # Executor.shutdown() accepts no timeout argument, so wait on the
        # futures themselves to honour the caller's timeout.
        _, not_done = concurrent.futures.wait(pending, timeout=timeout)
        if not_done:
            logger.warning(f"等待超时,仍有 {len(not_done)} 个任务未完成")
            return False

        logger.debug("所有任务已完成")
        return True
|
|
|
|
|
if __name__ == "__main__":
    # Manual smoke test: push two sample sentences through the dispatcher
    # with verbose logging and wait for both to come back.
    logger.setLevel(logging.DEBUG)

    def report_result(sentence_id, original_text, optimized_text):
        """Log the optimized text returned for one sentence."""
        logger.info(f"[回填] {sentence_id}: {optimized_text}")

    sample_sentences = (
        ("s001", "we maybe start tomorrow okay"),
        ("s002", "they need eat fast meeting now"),
    )

    dispatcher = OptimizationDispatcher(callback=report_result)
    for sample_id, sample_text in sample_sentences:
        dispatcher.submit(sample_id, sample_text)

    dispatcher.wait_until_done()

    logger.info("测试完成")
|
|