"""
Pseudo-streaming audio transcription + LLM optimization system: main program.

This program implements the complete audio processing pipeline:

1. VAD segmentation
2. Whisper transcription
3. Semantic aggregation
4. Immediate display
5. LLM optimization
6. Translation

Usage:
    python main.py [--audio_path AUDIO_PATH] [--use_gpu] [--enable_translation]
                   [--enable_optimization] [--whisper_model MODEL] [--log_level LEVEL]
"""

import os
import sys
import time
import logging
import argparse
import numpy as np
import soundfile as sf
from pathlib import Path
from typing import List, Dict, Optional, Tuple, Union
import uuid


def setup_logger(name, level=logging.INFO):
    """Set up a logger with a stream handler and a standard format."""
    logger = logging.getLogger(name)

    # Clear any existing handlers so repeated calls do not duplicate output.
    if logger.handlers:
        logger.handlers.clear()

    handler = logging.StreamHandler()
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    handler.setFormatter(formatter)
    logger.addHandler(handler)
    logger.setLevel(level)

    logger.propagate = False
    return logger


logger = setup_logger("main")

from vad import VoiceActivityDetector
from transcribe.transcribe import AudioTranscriber, TranscriptionResult
from aggregator.semantic_aggregator import SemanticAggregator
from display.display import OutputRenderer
from optimizer.dispatcher import OptimizationDispatcher
from translator.translator import NLLBTranslator

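
# Pipeline data flow (as wired below): VAD segments -> Whisper transcription ->
# SemanticAggregator, which invokes _display_callback (state "raw") and
# _translate_callback; raw sentences are also submitted to the
# OptimizationDispatcher, whose results come back through
# _optimization_callback (state "optimized") and are translated again.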
class AudioProcessingPipeline:
    """Complete audio processing pipeline."""

    def __init__(
        self,
        audio_path: str,
        use_gpu: bool = True,
        enable_translation: bool = True,
        enable_optimization: bool = True,
        whisper_model: str = "large",
        log_level: Union[int, str] = logging.INFO
    ):
        """
        Initialize the processing pipeline.

        :param audio_path: Path to the audio file
        :param use_gpu: Whether to use the GPU
        :param enable_translation: Whether to enable translation
        :param enable_optimization: Whether to enable LLM optimization
        :param whisper_model: Whisper model size (tiny, base, small, medium, large)
        :param log_level: Logging level
        """
        if isinstance(log_level, str):
            log_level = getattr(logging, log_level.upper())
        logger.setLevel(log_level)

        self.audio_path = audio_path
        self.use_gpu = use_gpu
        self.enable_translation = enable_translation
        self.enable_optimization = enable_optimization
        self.whisper_model = whisper_model

        self.device = "cuda" if use_gpu and self._is_gpu_available() else "cpu"
        logger.info(f"Using device: {self.device}")
        logger.debug(f"Config: whisper_model={whisper_model}, translation={enable_translation}, optimization={enable_optimization}")

        self._init_modules()

        self.audio_data, self.sample_rate = sf.read(audio_path)
        logger.info(f"Loaded audio: {os.path.basename(audio_path)}, duration: {len(self.audio_data)/self.sample_rate:.2f}s")
        logger.debug(f"Audio details: sample_rate={self.sample_rate}Hz, shape={self.audio_data.shape}")
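
        # Downmix to mono if the file is multi-channel (assumption: the VAD and
        # transcription steps expect a 1-D waveform; soundfile returns a
        # (frames, channels) array for stereo files).
        if self.audio_data.ndim > 1:
            self.audio_data = self.audio_data.mean(axis=1)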

        self.optimization_tasks: Dict[str, str] = {}

    def _is_gpu_available(self) -> bool:
        """Check whether a CUDA-capable GPU is available."""
        try:
            import torch
            if torch.cuda.is_available():
                logger.debug(f"GPU detected: {torch.cuda.get_device_name(0)}")
                return True
            return False
        except ImportError:
            logger.debug("PyTorch not installed; using CPU")
            return False

    def _init_modules(self):
        """Initialize the individual processing modules."""
        logger.debug("Initializing VAD module...")
        self.vad = VoiceActivityDetector(save_audio=True, save_json=True)

        logger.debug(f"Initializing Whisper transcription module (model={self.whisper_model}, device={self.device})...")
        self.transcriber = AudioTranscriber(
            model=self.whisper_model,
            device=self.device,
            compute_type="int8" if self.device == "cuda" else "float32"
        )

        logger.debug("Initializing display module...")
        self.renderer = OutputRenderer()

        if self.enable_optimization:
            logger.debug("Initializing LLM optimization dispatcher...")
            self.optimizer = OptimizationDispatcher(
                max_workers=2,
                callback=self._optimization_callback
            )
        else:
            logger.debug("LLM optimization disabled")
            self.optimizer = None

        if self.enable_translation:
            logger.debug("Initializing translation module...")
            try:
                self.translator = NLLBTranslator()
                self.translation_enabled = True
            except Exception as e:
                logger.warning(f"Translator initialization failed: {str(e)}")
                self.translation_enabled = False
                self.translator = None
        else:
            logger.debug("Translation disabled")
            self.translation_enabled = False
            self.translator = None

        logger.debug("Initializing semantic aggregation controller...")
        self.aggregator = SemanticAggregator(
            on_display=self._display_callback,
            on_translate=self._translate_callback,
            transcriber=self.transcriber,
            segments_dir="dataset/audio/segments",
            max_window=5.0,
            max_segments=5,
            min_gap=0.8,
            force_flush_timeout=3.0
        )

        logger.info("All modules initialized")

    def _display_callback(self, sentence_id: str, text: str, state: str):
        """Display callback: render the text and queue LLM optimization for raw sentences."""
        self.renderer.display(sentence_id, text, state)

        # Raw sentences are handed to the optimizer; its callback re-renders them
        # in the "optimized" state once the LLM has cleaned them up.
        if self.enable_optimization and state == "raw" and self.optimizer:
            logger.debug(f"Submitting optimization task: {sentence_id}")
            self.optimizer.submit(sentence_id, text)

    def _translate_callback(self, sentence_id: str, text: str):
        """Translation callback."""
        if self.translation_enabled and self.translator:
            try:
                # The return value is not captured here; the translator is assumed
                # to handle rendering of its own output.
                self.translator.translate(text)
                logger.debug(f"Translated sentence: {sentence_id}")
            except Exception as e:
                logger.error(f"Translation failed: {str(e)}")

    def _optimization_callback(self, sentence_id: str, original_text: str, optimized_text: str):
        """Optimization callback: re-render the optimized text and optionally translate it."""
        logger.debug(f"Received optimization result: {sentence_id}")

        self.renderer.display(sentence_id, optimized_text, "optimized")

        if self.translation_enabled:
            logger.debug(f"Translating optimized text: {sentence_id}")
            self._translate_callback(sentence_id, optimized_text)

    def process(self):
        """Process the audio file end to end."""
        logger.info("Starting audio processing...")

        # 1. VAD segmentation over the full waveform.
        logger.debug("Running VAD segmentation...")
        segments = self.vad.detect_voice_segments(self.audio_data, self.sample_rate)
        logger.info(f"VAD segmentation complete: {len(segments)} segments")

        # 2. Transcribe each voiced segment and feed the results to the aggregator.
        for i, (start, end) in enumerate(segments):
            logger.debug(f"Transcribing segment {i+1}/{len(segments)}: {start:.2f}s -> {end:.2f}s")

            segment_audio = self.audio_data[int(start * self.sample_rate):int(end * self.sample_rate)]

            results = self.transcriber.transcribe_segment(segment_audio, start_time=start)

            for result in results:
                result.segment_index = i + 1
                logger.debug(f"Adding transcription result: {result.text}")
                self.aggregator.add_segment(result)

            # Brief pause between segments (pseudo-streaming pacing).
            time.sleep(0.1)

        # 3. Flush any text still buffered in the aggregator.
        logger.debug("Force-flushing the buffer...")
        self.aggregator.flush(force=True)

        # 4. Wait for outstanding LLM optimization tasks before returning.
        if self.enable_optimization and self.optimizer:
            logger.debug("Waiting for all optimization tasks to finish...")
            self.optimizer.wait_until_done()

        logger.info("Audio processing complete")
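
# Programmatic usage sketch (illustrative values; main() below builds the same
# pipeline from command-line arguments):
#   pipeline = AudioProcessingPipeline("dataset/audio/test1.wav",
#                                      use_gpu=False,
#                                      enable_translation=False,
#                                      enable_optimization=False,
#                                      whisper_model="small")
#   pipeline.process()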


def parse_args():
    """Parse command-line arguments."""
    parser = argparse.ArgumentParser(description="Pseudo-streaming audio transcription + LLM optimization system")
    parser.add_argument("--audio_path", type=str, default="dataset/audio/test1.wav",
                        help="Path to the audio file")
    # BooleanOptionalAction (Python 3.9+) generates paired --flag / --no-flag
    # switches, so these default-on options can be disabled from the command line.
    parser.add_argument("--use_gpu", action=argparse.BooleanOptionalAction, default=True,
                        help="Use the GPU if available")
    parser.add_argument("--enable_translation", action=argparse.BooleanOptionalAction, default=True,
                        help="Enable translation")
    parser.add_argument("--enable_optimization", action=argparse.BooleanOptionalAction, default=True,
                        help="Enable LLM optimization")
    parser.add_argument("--whisper_model", type=str, default="small",
                        choices=["tiny", "base", "small", "medium", "large"],
                        help="Whisper model size")
    parser.add_argument("--log_level", type=str, default="INFO",
                        choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
                        help="Logging level")
    return parser.parse_args()


def main():
    """Program entry point."""
    args = parse_args()

    log_level = getattr(logging, args.log_level)

    # Apply the chosen log level to every module logger used by the pipeline.
    for module in ["main", "vad", "transcribe", "aggregator", "display", "optimizer", "translator"]:
        setup_logger(module, log_level)

    if not os.getenv("OPENAI_API_KEY") and args.enable_optimization:
        logger.warning("OPENAI_API_KEY is not set; sentence-completeness checks will fall back to an alternative method")

    if not os.path.exists(args.audio_path):
        logger.error(f"Audio file does not exist: {args.audio_path}")
        return

    pipeline = AudioProcessingPipeline(
        audio_path=args.audio_path,
        use_gpu=args.use_gpu,
        enable_translation=args.enable_translation,
        enable_optimization=args.enable_optimization,
        whisper_model=args.whisper_model,
        log_level=log_level
    )

    pipeline.process()


if __name__ == "__main__":
    main()