""" 伪流式音频转写 + LLM优化系统 主程序 这个程序实现了完整的音频处理流水线,包括: 1. VAD分段 2. Whisper转录 3. 语义聚合 4. 即时输出 5. LLM优化 6. 翻译 使用方法: python main.py [--audio_path AUDIO_PATH] [--use_gpu] [--enable_translation] [--enable_optimization] """ import os import sys import time import logging import argparse import numpy as np import soundfile as sf from pathlib import Path from typing import List, Dict, Optional, Tuple, Union import uuid # 配置日志 def setup_logger(name, level=logging.INFO): """设置日志记录器""" logger = logging.getLogger(name) # 清除所有已有的handler,避免重复 if logger.handlers: logger.handlers.clear() # 添加新的handler handler = logging.StreamHandler() formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') handler.setFormatter(formatter) logger.addHandler(handler) logger.setLevel(level) # 禁止传播到父logger,避免重复日志 logger.propagate = False return logger # 创建主日志记录器 logger = setup_logger("main") # 导入各个模块 from vad import VoiceActivityDetector from transcribe.transcribe import AudioTranscriber, TranscriptionResult from aggregator.semantic_aggregator import SemanticAggregator from display.display import OutputRenderer from optimizer.dispatcher import OptimizationDispatcher from translator.translator import NLLBTranslator class AudioProcessingPipeline: """完整的音频处理流水线""" def __init__( self, audio_path: str, use_gpu: bool = True, enable_translation: bool = True, enable_optimization: bool = True, whisper_model: str = "large", log_level: Union[int, str] = logging.INFO ): """ 初始化处理流水线 :param audio_path: 音频文件路径 :param use_gpu: 是否使用GPU :param enable_translation: 是否启用翻译 :param enable_optimization: 是否启用LLM优化 :param whisper_model: Whisper模型大小 (tiny, base, small, medium, large) :param log_level: 日志级别 """ # 设置日志级别 if isinstance(log_level, str): log_level = getattr(logging, log_level.upper()) logger.setLevel(log_level) self.audio_path = audio_path self.use_gpu = use_gpu self.enable_translation = enable_translation self.enable_optimization = enable_optimization self.whisper_model = whisper_model # 检查设备 self.device = "cuda" if use_gpu and self._is_gpu_available() else "cpu" logger.info(f"使用设备: {self.device}") logger.debug(f"配置: whisper_model={whisper_model}, translation={enable_translation}, optimization={enable_optimization}") # 初始化各个模块 self._init_modules() # 加载音频数据 self.audio_data, self.sample_rate = sf.read(audio_path) logger.info(f"加载音频: {os.path.basename(audio_path)}, 长度: {len(self.audio_data)/self.sample_rate:.2f}秒") logger.debug(f"音频详情: 采样率={self.sample_rate}Hz, 形状={self.audio_data.shape}") # 存储句子ID到优化任务的映射 self.optimization_tasks: Dict[str, str] = {} def _is_gpu_available(self) -> bool: """检查GPU是否可用""" try: import torch if torch.cuda.is_available(): logger.debug(f"检测到GPU: {torch.cuda.get_device_name(0)}") return True return False except ImportError: logger.debug("未检测到PyTorch,将使用CPU") return False def _init_modules(self): """初始化各个处理模块""" # 1. 初始化VAD logger.debug("初始化VAD模块...") self.vad = VoiceActivityDetector(save_audio=True, save_json=True) # 2. 初始化转录器 logger.debug(f"初始化Whisper转录模块 (model={self.whisper_model}, device={self.device})...") self.transcriber = AudioTranscriber( model=self.whisper_model, device=self.device, compute_type="int8" if self.device == "cuda" else "float32" ) # 3. 初始化显示器 logger.debug("初始化显示模块...") self.renderer = OutputRenderer() # 4. 初始化优化调度器(如果启用) if self.enable_optimization: logger.debug("初始化LLM优化调度器...") self.optimizer = OptimizationDispatcher( max_workers=2, callback=self._optimization_callback ) else: logger.debug("LLM优化已禁用") self.optimizer = None # 5. 初始化翻译器(如果启用) if self.enable_translation: logger.debug("初始化翻译模块...") try: self.translator = NLLBTranslator() self.translation_enabled = True except Exception as e: logger.warning(f"翻译器初始化失败: {str(e)}") self.translation_enabled = False else: logger.debug("翻译已禁用") self.translation_enabled = False self.translator = None # 6. 初始化聚合器 logger.debug("初始化语义聚合控制器...") self.aggregator = SemanticAggregator( on_display=self._display_callback, on_translate=self._translate_callback, transcriber=self.transcriber, segments_dir="dataset/audio/segments", max_window=5.0, max_segments=5, min_gap=0.8, force_flush_timeout=3.0 ) logger.info("所有模块初始化完成") def _display_callback(self, sentence_id: str, text: str, state: str): """显示回调函数""" self.renderer.display(sentence_id, text, state) # 如果启用了优化,且是原始文本,则提交优化任务 if self.enable_optimization and state == "raw" and self.optimizer: logger.debug(f"提交优化任务: {sentence_id}") self.optimizer.submit(sentence_id, text) def _translate_callback(self, sentence_id: str, text: str): """翻译回调函数""" if self.translation_enabled and self.translator: try: # 翻译模块内部已经记录了原文和结果,这里只需调用翻译方法 self.translator.translate(text) logger.debug(f"已翻译句子: {sentence_id}") except Exception as e: logger.error(f"翻译失败: {str(e)}") def _optimization_callback(self, sentence_id: str, original_text: str, optimized_text: str): """优化回调函数""" logger.debug(f"收到优化结果: {sentence_id}") # 更新显示 self.renderer.display(sentence_id, optimized_text, "optimized") # 如果启用了翻译,则翻译优化后的文本 if self.translation_enabled: logger.debug(f"翻译优化后的文本: {sentence_id}") self._translate_callback(sentence_id, optimized_text) def process(self): """处理音频文件""" logger.info("开始处理音频...") # 1. VAD分段 logger.debug("执行VAD分段...") segments = self.vad.detect_voice_segments(self.audio_data, self.sample_rate) logger.info(f"VAD分段完成: {len(segments)}个片段") # 2. 转录每个片段 for i, (start, end) in enumerate(segments): logger.debug(f"转录片段 {i+1}/{len(segments)}: {start:.2f}s -> {end:.2f}s") # 提取片段音频数据 segment_audio = self.audio_data[int(start * self.sample_rate):int(end * self.sample_rate)] # 转录片段 results = self.transcriber.transcribe_segment(segment_audio, start_time=start) # 添加片段索引 for result in results: result.segment_index = i + 1 # 片段索引从1开始 # 添加到聚合器 for result in results: logger.debug(f"添加转录结果: {result.text}") self.aggregator.add_segment(result) # 模拟处理延迟 time.sleep(0.1) # 3. 最后强制刷新缓冲区 logger.debug("强制刷新缓冲区...") self.aggregator.flush(force=True) # 4. 等待所有优化任务完成 if self.enable_optimization and self.optimizer: logger.debug("等待所有优化任务完成...") self.optimizer.wait_until_done() logger.info("音频处理完成") def parse_args(): """解析命令行参数""" parser = argparse.ArgumentParser(description="伪流式音频转写 + LLM优化系统") parser.add_argument("--audio_path", type=str, default="dataset/audio/test1.wav", help="音频文件路径") parser.add_argument("--use_gpu", action="store_true", default=True, help="是否使用GPU") parser.add_argument("--enable_translation", action="store_true", default=True, help="是否启用翻译") parser.add_argument("--enable_optimization", action="store_true", default=True, help="是否启用LLM优化") parser.add_argument("--whisper_model", type=str, default="small", choices=["tiny", "base", "small", "medium", "large"], help="Whisper模型大小") parser.add_argument("--log_level", type=str, default="INFO", choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], help="日志级别") return parser.parse_args() def main(): """主函数""" # 解析命令行参数 args = parse_args() # 设置日志级别 log_level = getattr(logging, args.log_level) # 设置所有模块的日志级别 for module in ["main", "vad", "transcribe", "aggregator", "display", "optimizer", "translator"]: setup_logger(module, log_level) # 检查OpenAI API密钥(用于句子完整性判断和优化) if not os.getenv("OPENAI_API_KEY") and args.enable_optimization: logger.warning("未设置OPENAI_API_KEY环境变量,句子完整性判断将使用备用方法") # 检查音频文件是否存在 if not os.path.exists(args.audio_path): logger.error(f"音频文件不存在: {args.audio_path}") return # 创建并运行处理流水线 pipeline = AudioProcessingPipeline( audio_path=args.audio_path, use_gpu=args.use_gpu, enable_translation=args.enable_translation, enable_optimization=args.enable_optimization, whisper_model=args.whisper_model, log_level=log_level ) # 处理音频 pipeline.process() if __name__ == "__main__": main()