# faster-whisper-small / aggregator / semantic_aggregator.py
# Author: tianyaogavin — commit 1bf36cc ("init main framework")
from typing import List, Callable, Optional, Dict, Tuple
import uuid
import time
import os
import numpy as np
import soundfile as sf
import logging
from openai import OpenAI
from transcribe.transcribe import TranscriptionResult, AudioTranscriber
# Logging configuration
def setup_logger(name, level=logging.INFO):
    """Create (or reconfigure) a console logger with a standard format.

    Any handlers already attached to the named logger are removed first so
    repeated calls never duplicate output; propagation to ancestor loggers
    is disabled for the same reason.
    """
    log = logging.getLogger(name)
    # Drop pre-existing handlers to avoid duplicated log lines.
    if log.handlers:
        log.handlers.clear()
    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(
        logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    )
    log.addHandler(stream_handler)
    log.setLevel(level)
    # Keep messages from bubbling up to the root logger.
    log.propagate = False
    return log
# Module-level logger shared by everything in this file
logger = setup_logger("aggregator")
class SentenceCompletionDetector:
    """Decide whether a piece of text forms a complete sentence.

    Primary classifier: a ChatGPT few-shot prompt expected to answer a
    bare "True"/"False". Fallback (on any API error): a simple length
    heuristic.
    """
    def __init__(self, model="gpt-3.5-turbo"):
        """
        :param model: OpenAI chat model used for the classification call
        """
        self.model = model
        # API key comes from the environment; no network I/O happens until
        # a completion is actually requested.
        self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
    def build_prompt(self, text: str) -> str:
        """Build the few-shot prompt asking whether *text* ends a sentence.

        The examples are in Chinese because the transcripts this project
        processes are Chinese; the model is expected to reply with a bare
        "True" or "False".
        """
        return (
            "判断以下语句是否为一句话的结尾,如果是,返回 True,否则返回 False:\n"
            "\"你会学习到如何使用音频数据集,包括音频数据加载\"\n"
            "False\n\n"
            "\"你会学习到如何使用音频数据集,包括音频数据加载,音频数据预处理,以及高效加载大规模音频数据集的流式加载方法\"\n"
            "True\n\n"
            "\"在开始学习之前,我们需要\"\n"
            "False\n\n"
            "\"在开始学习之前,我们需要了解一些基本概念\"\n"
            "True\n\n"
            "\"第一章,介绍基础知识\"\n"
            "True\n\n"
            f"\"{text}\"\n"
        )
    def is_sentence_complete(self, text: str) -> bool:
        """Return True if *text* reads as a complete sentence.

        Uses ChatGPT via the few-shot prompt; on any API failure falls
        back to treating text longer than 20 characters as complete.
        """
        prompt = self.build_prompt(text)
        try:
            response = self.client.chat.completions.create(
                model=self.model,
                messages=[
                    {"role": "system", "content": "你是一个语言专家,擅长判断句子是否完整。"},
                    {"role": "user", "content": prompt}
                ],
                temperature=0.1,
                max_tokens=10,
            )
            result = response.choices[0].message.content.strip()
            logger.debug(f"ChatGPT判断结果: {result}")
            # Be tolerant of minor decoration around the answer (e.g.
            # "True." or trailing tokens): accept any reply that starts
            # with "true" after lower-casing, instead of requiring an
            # exact "true" match.
            return result.lower().startswith("true")
        except Exception as e:
            logger.error(f"调用ChatGPT出错: {str(e)}")
            # Degraded mode: longer text is more likely to be complete.
            return len(text) > 20
class SemanticAggregator:
    """
    Semantic aggregation controller.

    - maintains a buffer pool of transcribed segments
    - decides when the buffer forms a complete semantic unit
    - pushes aggregated results downstream (display / translator)
    """
    def __init__(
        self,
        on_display: Callable[[str, str, str], None],
        on_translate: Callable[[str, str], None],
        transcriber: AudioTranscriber,
        segments_dir: str = "dataset/audio/segments",
        max_window: float = 5.0,
        max_segments: int = 5,
        min_gap: float = 0.8,
        force_flush_timeout: float = 3.0,
        segment_prefix: str = "test1"
    ):
        """
        :param on_display: display callback (sentence_id, text, state)
        :param on_translate: translation callback (sentence_id, text)
        :param transcriber: transcriber instance used for re-transcription
        :param segments_dir: directory holding the per-segment audio files
        :param max_window: maximum aggregation window (seconds)
        :param max_segments: maximum number of segments per aggregation
        :param min_gap: minimum inter-segment gap (seconds) that triggers aggregation
        :param force_flush_timeout: forced-flush timeout (seconds)
        :param segment_prefix: filename prefix of the segment wav files
            ("<prefix>_segment_<index>.wav"); defaults to the previously
            hard-coded "test1" for backward compatibility
        """
        self.buffer: List[TranscriptionResult] = []
        self.on_display = on_display
        self.on_translate = on_translate
        self.transcriber = transcriber
        self.segments_dir = segments_dir
        self.segment_prefix = segment_prefix
        self.max_window = max_window
        self.max_segments = max_segments
        self.min_gap = min_gap
        self.force_flush_timeout = force_flush_timeout
        self.last_flush_time = time.time()
        self.sentence_detector = SentenceCompletionDetector()
        # Per-index audio cache so aggregation never re-reads the same file.
        self.audio_cache: Dict[int, np.ndarray] = {}
        self.sample_rate = 16000  # assumed 16 kHz — TODO confirm against the segment files
        logger.debug(f"语义聚合器初始化完成,参数: max_window={max_window}, max_segments={max_segments}")

    def add_segment(self, result: TranscriptionResult):
        """Append a new transcription segment and aggregate if ready.

        Aggregation fires either when `_should_aggregate` says the buffer
        forms a semantic unit, or when `force_flush_timeout` seconds have
        passed since the last flush.
        """
        self.buffer.append(result)
        logger.debug(f"添加片段: {result.text}")
        if self._should_aggregate():
            self._aggregate_and_flush()
        elif time.time() - self.last_flush_time > self.force_flush_timeout:
            logger.debug(f"超时强制刷新: {self.force_flush_timeout}秒")
            self.flush(force=True)

    def flush(self, force: bool = False):
        """Flush the currently buffered content downstream.

        :param force: accepted for API symmetry; flushing currently happens
            whenever the buffer is non-empty regardless of this flag —
            NOTE(review): confirm whether non-forced flush should gate on
            `_should_aggregate` instead.
        """
        if self.buffer:
            logger.debug(f"强制刷新缓冲区,当前片段数: {len(self.buffer)}")
            self._aggregate_and_flush()
        self.last_flush_time = time.time()

    def _should_aggregate(self) -> bool:
        """Return True when the buffered segments should be aggregated.

        Triggers, in priority order:
        1. ChatGPT judges the comma-joined buffer text a complete sentence
        2. the gap between the last two segments exceeds `min_gap`
        3. total buffered duration exceeds `max_window`
        4. buffered segment count reaches `max_segments`
        """
        if not self.buffer:
            return False
        # 1. Completeness check on the comma-joined text (same joining as
        # _aggregate_and_flush so both paths see identical text).
        segments = [seg.text for seg in self.buffer]
        combined_text = ",".join(segments)
        if self.sentence_detector.is_sentence_complete(combined_text):
            logger.info(f"检测到完整句子: {combined_text}")
            return True
        # 2. Inter-segment gap.
        if len(self.buffer) >= 2:
            gap = self.buffer[-1].start_time - self.buffer[-2].end_time
            if gap > self.min_gap:
                logger.info(f"检测到较大间隔: {gap:.2f}秒")
                return True
        # 3. Maximum window / segment count.
        total_duration = self.buffer[-1].end_time - self.buffer[0].start_time
        if total_duration > self.max_window:
            logger.info(f"达到最大时间窗口: {total_duration:.2f}秒")
            return True
        if len(self.buffer) >= self.max_segments:
            logger.info(f"达到最大片段数: {len(self.buffer)}")
            return True
        return False

    def _get_segment_audio(self, segment_index: int) -> np.ndarray:
        """Return the audio samples for one segment index (cached).

        Returns an empty array when the file cannot be read, so callers
        can simply skip bad segments.
        """
        if segment_index in self.audio_cache:
            return self.audio_cache[segment_index]
        # Read the segment wav from disk.
        audio_path = os.path.join(
            self.segments_dir, f"{self.segment_prefix}_segment_{segment_index}.wav"
        )
        try:
            audio_data, sample_rate = sf.read(audio_path)
            self.audio_cache[segment_index] = audio_data
            logger.debug(f"读取音频片段: {audio_path}, 长度: {len(audio_data)/sample_rate:.2f}秒")
            return audio_data
        except Exception as e:
            logger.error(f"读取音频文件失败: {audio_path}, 错误: {str(e)}")
            return np.array([])

    def _combine_audio_segments(self, segment_indices: List[int]) -> Tuple[np.ndarray, float]:
        """Concatenate the audio of the given segment indices.

        :return: (combined audio samples, start time). NOTE(review): the
            start time is taken from the first *buffered* segment, not from
            the first index in `segment_indices` — confirm these always
            coincide.
        """
        if not segment_indices:
            return np.array([]), 0.0
        # Collect readable segments, skipping any that failed to load.
        audio_segments = []
        for idx in segment_indices:
            audio_data = self._get_segment_audio(idx)
            if len(audio_data) > 0:
                audio_segments.append(audio_data)
        if not audio_segments:
            return np.array([]), 0.0
        combined_audio = np.concatenate(audio_segments)
        # Anchor the combined clip at the first buffered segment's start.
        first_segment = self.buffer[0]
        start_time = first_segment.start_time
        logger.debug(f"合并音频片段: {segment_indices}, 总长度: {len(combined_audio)/self.sample_rate:.2f}秒")
        return combined_audio, start_time

    def _retranscribe_segments(self, segment_indices: List[int]) -> List[TranscriptionResult]:
        """Re-transcribe the concatenated audio of the given segments.

        Returns an empty list when no audio is available or transcription
        fails, letting the caller fall back to the original text.
        """
        combined_audio, start_time = self._combine_audio_segments(segment_indices)
        if len(combined_audio) == 0:
            logger.warning("没有有效的音频数据可以重新转录")
            return []
        logger.debug(f"重新转录合并的音频片段, 长度: {len(combined_audio)/self.sample_rate:.2f}秒")
        try:
            results = self.transcriber.transcribe_segment(combined_audio, start_time=start_time)
            logger.debug(f"重新转录结果: {len(results)}条")
            return results
        except Exception as e:
            logger.error(f"重新转录失败: {str(e)}")
            return []

    def _aggregate_and_flush(self):
        """Aggregate the buffer, push downstream, then reset state.

        Pipeline: display the raw comma-joined text immediately, then try a
        re-transcription over the concatenated audio for an improved
        version, and finally send the best available text to translation.
        """
        if not self.buffer:
            return
        # Collect every segment index present in the buffer (a segment may
        # carry a single index or a list of them).
        segment_indices = []
        for seg in self.buffer:
            if hasattr(seg, 'segment_index') and seg.segment_index is not None:
                if isinstance(seg.segment_index, list):
                    segment_indices.extend(seg.segment_index)
                else:
                    segment_indices.append(seg.segment_index)
        # Deduplicate and sort.
        segment_indices = sorted(set(segment_indices))
        # One id ties together the raw/optimized display and the translation.
        sentence_id = str(uuid.uuid4())
        # 1. Emit the raw text first (comma-joined, no trailing period) so
        # the UI shows something immediately.
        original_segments = [seg.text for seg in self.buffer]
        original_text = ",".join(original_segments)
        logger.info(f"原始聚合文本: {original_text}")
        self.on_display(sentence_id, original_text, "raw")
        # 2. Re-transcribe for a potentially better version.
        if segment_indices:
            retranscribed_results = self._retranscribe_segments(segment_indices)
            if retranscribed_results:
                retranscribed_segments = [res.text for res in retranscribed_results]
                retranscribed_text = ",".join(retranscribed_segments)
                logger.info(f"重新转录文本: {retranscribed_text}")
                # Only re-display when the text actually changed.
                if retranscribed_text != original_text:
                    self.on_display(sentence_id, retranscribed_text, "optimized")
                self.on_translate(sentence_id, retranscribed_text)
            else:
                # Re-transcription failed: translate the raw text instead.
                logger.warning("重新转录失败,使用原始文本进行翻译")
                self.on_translate(sentence_id, original_text)
        else:
            # No usable segment indices: translate the raw text.
            logger.warning("没有有效的片段索引,使用原始文本进行翻译")
            self.on_translate(sentence_id, original_text)
        # Reset state. The audio cache is cleared too: flushed indices are
        # never requested again, so keeping them would only grow memory
        # without bound on long streams.
        buffer_size = len(self.buffer)
        self.buffer.clear()
        self.audio_cache.clear()
        self.last_flush_time = time.time()
        logger.debug(f"清空缓冲区,释放 {buffer_size} 个片段")
def load_transcription_results(json_path):
    """Load transcription results from a JSON transcript file.

    The file is expected to contain a top-level ``"segments"`` list; each
    entry is converted into a TranscriptionResult.

    :param json_path: path to the JSON transcript file
    :return: list of TranscriptionResult objects
    :raises KeyError: if a segment is missing a required field
    """
    import json
    with open(json_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    results = []
    for segment in data['segments']:
        result = TranscriptionResult(
            text=segment['text'],
            start_time=segment['start_time'],
            end_time=segment['end_time'],
            confidence=segment['confidence'],
            verified=segment['verified'],
            verified_text=segment['verified_text'],
            verification_notes=segment['verification_notes'],
            # .get() handles older transcript files without this field
            # (single lookup instead of the `x if k in d else None` form).
            segment_index=segment.get('segment_index')
        )
        results.append(result)
    return results
if __name__ == "__main__":
    """测试聚合器功能"""
    import os
    import sys
    import json
    from pathlib import Path
    # Verbose logging for this manual test run
    logger.setLevel(logging.DEBUG)
    # Without an API key, sentence-completion detection falls back to the
    # length heuristic inside SentenceCompletionDetector
    if not os.getenv("OPENAI_API_KEY"):
        logger.warning("未设置OPENAI_API_KEY环境变量,句子完整性判断将使用备用方法")
    # Wire up the pipeline modules
    from display.display import OutputRenderer
    from translator.translator import NLLBTranslator
    from transcribe.transcribe import AudioTranscriber
    # Display renderer
    renderer = OutputRenderer()
    # Transcriber: prefer GPU, fall back to CPU on any init failure
    try:
        transcriber = AudioTranscriber(model="small", device="cuda", compute_type="int8")
        logger.info("使用GPU进行转录")
    except Exception as e:
        logger.warning(f"GPU初始化失败,使用CPU: {str(e)}")
        transcriber = AudioTranscriber(model="small", device="cpu", compute_type="float32")
    # Translator is optional — the pipeline still runs without it
    try:
        translator = NLLBTranslator()
        translation_enabled = True
    except Exception as e:
        logger.warning(f"翻译器初始化失败: {str(e)}")
        translation_enabled = False
    # Downstream callbacks handed to the aggregator
    def display_callback(sentence_id, text, state):
        renderer.display(sentence_id, text, state)
    def translate_callback(sentence_id, text):
        if translation_enabled:
            try:
                translation = translator.translate(text)
                logger.info(f"[翻译] 句子 {sentence_id}: {translation}")
            except Exception as e:
                logger.error(f"翻译失败: {str(e)}")
        else:
            logger.info(f"[翻译已禁用] 句子 {sentence_id}: {text}")
    # Aggregator under test
    aggregator = SemanticAggregator(
        on_display=display_callback,
        on_translate=translate_callback,
        transcriber=transcriber,
        segments_dir="dataset/audio/segments",
        max_window=10.0, # widened window for easier testing
        max_segments=10, # raised segment cap for easier testing
        force_flush_timeout=5.0 # longer timeout for easier testing
    )
    # Load the recorded transcription results used as test input
    test_file = "dataset/transcripts/test1_segment_1_20250423_201934.json"
    try:
        results = load_transcription_results(test_file)
        logger.info(f"加载了 {len(results)} 条转录结果")
    except Exception as e:
        logger.error(f"加载转录结果失败: {str(e)}")
        sys.exit(1)
    # Feed the results in one by one, as live transcription would
    for i, result in enumerate(results):
        logger.info(f"添加第 {i+1}/{len(results)} 条转录结果: {result.text}")
        aggregator.add_segment(result)
        # Simulated processing delay (disabled)
        # time.sleep(0.5)
    # Flush whatever is still buffered
    aggregator.flush(force=True)
    logger.info("测试完成")