from typing import List, Callable, Optional, Dict, Tuple

import uuid
import time
import os
import numpy as np
import soundfile as sf
import logging

from openai import OpenAI

from transcribe.transcribe import TranscriptionResult, AudioTranscriber


def setup_logger(name, level=logging.INFO):
    """Set up a stream logger with a standard output format."""
    logger = logging.getLogger(name)

    # Drop any handlers left over from a previous configuration.
    if logger.handlers:
        logger.handlers.clear()

    handler = logging.StreamHandler()
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    handler.setFormatter(formatter)
    logger.addHandler(handler)
    logger.setLevel(level)

    # Do not propagate to the root logger, to avoid duplicate output.
    logger.propagate = False
    return logger


logger = setup_logger("aggregator")


class SentenceCompletionDetector:
    """
    Uses ChatGPT to decide whether a piece of text ends a complete sentence.
    """

    def __init__(self, model="gpt-3.5-turbo"):
        self.model = model
        self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

    def build_prompt(self, text: str) -> str:
        # Few-shot prompt, kept in Chinese to match the transcribed audio.
        # It asks: "Decide whether the following text ends a sentence;
        # return True if it does, otherwise False."
        return (
            "判断以下语句是否为一句话的结尾,如果是,返回 True,否则返回 False:\n"
            "\"你会学习到如何使用音频数据集,包括音频数据加载\"\n"
            "False\n\n"
            "\"你会学习到如何使用音频数据集,包括音频数据加载,音频数据预处理,以及高效加载大规模音频数据集的流式加载方法\"\n"
            "True\n\n"
            "\"在开始学习之前,我们需要\"\n"
            "False\n\n"
            "\"在开始学习之前,我们需要了解一些基本概念\"\n"
            "True\n\n"
            "\"第一章,介绍基础知识\"\n"
            "True\n\n"
            f"\"{text}\"\n"
        )

    def is_sentence_complete(self, text: str) -> bool:
        """
        Return True if the text reads as a complete sentence.
        """
        prompt = self.build_prompt(text)
        try:
            response = self.client.chat.completions.create(
                model=self.model,
                messages=[
                    {"role": "system", "content": "你是一个语言专家,擅长判断句子是否完整。"},
                    {"role": "user", "content": prompt}
                ],
                temperature=0.1,
                max_tokens=10,
            )
            result = response.choices[0].message.content.strip()
            logger.debug(f"ChatGPT completion verdict: {result}")
            return result.lower() == "true"
        except Exception as e:
            logger.error(f"ChatGPT call failed: {str(e)}")
            # Fallback heuristic when the API is unavailable:
            # treat sufficiently long text as a complete sentence.
            return len(text) > 20

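# Illustrative usage of the detector (not executed here; assumes OPENAI_API_KEY
# is set, otherwise the length-based fallback above takes over):
#
#   detector = SentenceCompletionDetector()
#   detector.is_sentence_complete("在开始学习之前,我们需要")                # expected: False
#   detector.is_sentence_complete("在开始学习之前,我们需要了解一些基本概念")  # expected: True
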
class SemanticAggregator:
    """
    Semantic aggregation controller.
    - Maintains a buffer of transcription segments
    - Decides when the buffer forms a complete semantic unit
    - Pushes results downstream (display / translator)
    """

    def __init__(
        self,
        on_display: Callable[[str, str, str], None],
        on_translate: Callable[[str, str], None],
        transcriber: AudioTranscriber,
        segments_dir: str = "dataset/audio/segments",
        max_window: float = 5.0,
        max_segments: int = 5,
        min_gap: float = 0.8,
        force_flush_timeout: float = 3.0
    ):
        """
        :param on_display: display callback (sentence_id, text, state)
        :param on_translate: translation callback (sentence_id, text)
        :param transcriber: transcriber instance used for re-transcription
        :param segments_dir: directory containing the audio segment files
        :param max_window: maximum aggregation window in seconds
        :param max_segments: maximum number of segments to aggregate
        :param min_gap: minimum inter-segment gap, in seconds, that triggers aggregation
        :param force_flush_timeout: timeout, in seconds, after which the buffer is force-flushed
        """
        self.buffer: List[TranscriptionResult] = []
        self.on_display = on_display
        self.on_translate = on_translate
        self.transcriber = transcriber
        self.segments_dir = segments_dir
        self.max_window = max_window
        self.max_segments = max_segments
        self.min_gap = min_gap
        self.force_flush_timeout = force_flush_timeout
        self.last_flush_time = time.time()
        self.sentence_detector = SentenceCompletionDetector()
        self.audio_cache: Dict[int, np.ndarray] = {}
        self.sample_rate = 16000
        logger.debug(f"Semantic aggregator initialized: max_window={max_window}, max_segments={max_segments}")

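    # Illustrative wiring of the callbacks (see the __main__ block at the bottom
    # of this file for the full demo); the signatures follow the docstring above:
    #
    #   def show(sentence_id, text, state): print(state, text)
    #   def forward(sentence_id, text): print("->", text)
    #   aggregator = SemanticAggregator(on_display=show, on_translate=forward,
    #                                   transcriber=my_transcriber)  # an AudioTranscriber
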
    def add_segment(self, result: TranscriptionResult):
        """
        Append a transcription segment to the buffer and aggregate if the conditions are met.
        """
        self.buffer.append(result)
        logger.debug(f"Added segment: {result.text}")
        if self._should_aggregate():
            self._aggregate_and_flush()
        elif time.time() - self.last_flush_time > self.force_flush_timeout:
            logger.debug(f"Force flush after timeout: {self.force_flush_timeout}s")
            self.flush(force=True)

    def flush(self, force: bool = False):
        """
        Force the current buffer contents to be aggregated and emitted.
        (`force` is accepted for API symmetry but is not otherwise used.)
        """
        if self.buffer:
            logger.debug(f"Flushing buffer, segment count: {len(self.buffer)}")
            self._aggregate_and_flush()
        self.last_flush_time = time.time()

    def _should_aggregate(self) -> bool:
        """
        Check whether the aggregation conditions are met.
        """
        if not self.buffer:
            return False

        # Condition 1: the combined text already reads as a complete sentence.
        segments = [seg.text for seg in self.buffer]
        combined_text = ",".join(segments)
        if self.sentence_detector.is_sentence_complete(combined_text):
            logger.info(f"Complete sentence detected: {combined_text}")
            return True

        # Condition 2: a long pause between the last two segments.
        if len(self.buffer) >= 2:
            gap = self.buffer[-1].start_time - self.buffer[-2].end_time
            if gap > self.min_gap:
                logger.info(f"Large gap detected: {gap:.2f}s")
                return True

        # Condition 3: the buffer exceeds the time window or the segment limit.
        total_duration = self.buffer[-1].end_time - self.buffer[0].start_time
        if total_duration > self.max_window:
            logger.info(f"Maximum time window reached: {total_duration:.2f}s")
            return True
        if len(self.buffer) >= self.max_segments:
            logger.info(f"Maximum segment count reached: {len(self.buffer)}")
            return True

        return False

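    # Worked example of condition 2 (values illustrative): with min_gap=0.8, a
    # buffered segment ending at t=12.3s followed by one starting at t=13.4s
    # gives gap = 13.4 - 12.3 = 1.1s > 0.8s, so the buffer is flushed.
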
    def _get_segment_audio(self, segment_index: int) -> np.ndarray:
        """
        Load (and cache) the audio data for the segment with the given index.
        """
        if segment_index in self.audio_cache:
            return self.audio_cache[segment_index]

        # Note: the "test1_" filename prefix is currently hardcoded for the test recording.
        audio_path = os.path.join(self.segments_dir, f"test1_segment_{segment_index}.wav")
        try:
            audio_data, sample_rate = sf.read(audio_path)
            self.audio_cache[segment_index] = audio_data
            logger.debug(f"Loaded audio segment: {audio_path}, duration: {len(audio_data)/sample_rate:.2f}s")
            return audio_data
        except Exception as e:
            logger.error(f"Failed to read audio file: {audio_path}, error: {str(e)}")
            return np.array([])

    def _combine_audio_segments(self, segment_indices: List[int]) -> Tuple[np.ndarray, float]:
        """
        Concatenate multiple audio segments.
        Returns: (combined audio data, start time)
        """
        if not segment_indices:
            return np.array([]), 0.0

        audio_segments = []
        for idx in segment_indices:
            audio_data = self._get_segment_audio(idx)
            if len(audio_data) > 0:
                audio_segments.append(audio_data)

        if not audio_segments:
            return np.array([]), 0.0

        combined_audio = np.concatenate(audio_segments)

        # Use the start time of the first buffered segment as the time offset.
        first_segment = self.buffer[0]
        start_time = first_segment.start_time

        logger.debug(f"Combined audio segments: {segment_indices}, total duration: {len(combined_audio)/self.sample_rate:.2f}s")
        return combined_audio, start_time

    def _retranscribe_segments(self, segment_indices: List[int]) -> List[TranscriptionResult]:
        """
        Re-transcribe the combined audio of the given segments.
        """
        combined_audio, start_time = self._combine_audio_segments(segment_indices)
        if len(combined_audio) == 0:
            logger.warning("No valid audio data available for re-transcription")
            return []

        logger.debug(f"Re-transcribing combined audio, duration: {len(combined_audio)/self.sample_rate:.2f}s")
        try:
            results = self.transcriber.transcribe_segment(combined_audio, start_time=start_time)
            logger.debug(f"Re-transcription produced {len(results)} results")
            return results
        except Exception as e:
            logger.error(f"Re-transcription failed: {str(e)}")
            return []

    def _aggregate_and_flush(self):
        """
        Aggregate the buffer and push the result downstream.
        """
        if not self.buffer:
            return

        # Collect the audio segment indices referenced by the buffered results.
        segment_indices = []
        for seg in self.buffer:
            if hasattr(seg, 'segment_index') and seg.segment_index is not None:
                if isinstance(seg.segment_index, list):
                    segment_indices.extend(seg.segment_index)
                else:
                    segment_indices.append(seg.segment_index)

        # Deduplicate and sort.
        segment_indices = sorted(set(segment_indices))

        sentence_id = str(uuid.uuid4())

        # Show the raw aggregated text immediately.
        original_segments = [seg.text for seg in self.buffer]
        original_text = ",".join(original_segments)
        logger.info(f"Raw aggregated text: {original_text}")
        self.on_display(sentence_id, original_text, "raw")

        # Re-transcribe the combined audio; if the text changes, show the optimized version.
        if segment_indices:
            retranscribed_results = self._retranscribe_segments(segment_indices)
            if retranscribed_results:
                retranscribed_segments = [res.text for res in retranscribed_results]
                retranscribed_text = ",".join(retranscribed_segments)
                logger.info(f"Re-transcribed text: {retranscribed_text}")

                if retranscribed_text != original_text:
                    self.on_display(sentence_id, retranscribed_text, "optimized")

                self.on_translate(sentence_id, retranscribed_text)
            else:
                logger.warning("Re-transcription failed; translating the original text")
                self.on_translate(sentence_id, original_text)
        else:
            logger.warning("No valid segment indices; translating the original text")
            self.on_translate(sentence_id, original_text)

        # Clear the buffer and reset the flush timer.
        buffer_size = len(self.buffer)
        self.buffer.clear()
        self.last_flush_time = time.time()
        logger.debug(f"Buffer cleared, released {buffer_size} segments")


def load_transcription_results(json_path):
    """Load transcription results from a JSON file."""
    import json
    with open(json_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    results = []
    for segment in data['segments']:
        result = TranscriptionResult(
            text=segment['text'],
            start_time=segment['start_time'],
            end_time=segment['end_time'],
            confidence=segment['confidence'],
            verified=segment['verified'],
            verified_text=segment['verified_text'],
            verification_notes=segment['verification_notes'],
            segment_index=segment.get('segment_index')
        )
        results.append(result)

    return results

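# The loader above expects JSON of roughly this shape (values are illustrative;
# only the keys are dictated by the code, and "segment_index" is optional):
#
# {
#   "segments": [
#     {
#       "text": "...", "start_time": 0.0, "end_time": 2.4,
#       "confidence": 0.92, "verified": true, "verified_text": "...",
#       "verification_notes": "", "segment_index": 1
#     }
#   ]
# }
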
if __name__ == "__main__":
    """Exercise the aggregator end to end on a saved transcription file."""
    import os
    import sys
    import json
    from pathlib import Path

    logger.setLevel(logging.DEBUG)

    if not os.getenv("OPENAI_API_KEY"):
        logger.warning("OPENAI_API_KEY is not set; sentence completion will fall back to the length heuristic")

    from display.display import OutputRenderer
    from translator.translator import NLLBTranslator
    from transcribe.transcribe import AudioTranscriber

    renderer = OutputRenderer()

    # Prefer GPU transcription; fall back to CPU if CUDA initialization fails.
    try:
        transcriber = AudioTranscriber(model="small", device="cuda", compute_type="int8")
        logger.info("Transcribing on GPU")
    except Exception as e:
        logger.warning(f"GPU initialization failed, falling back to CPU: {str(e)}")
        transcriber = AudioTranscriber(model="small", device="cpu", compute_type="float32")

    try:
        translator = NLLBTranslator()
        translation_enabled = True
    except Exception as e:
        logger.warning(f"Translator initialization failed: {str(e)}")
        translation_enabled = False

    def display_callback(sentence_id, text, state):
        renderer.display(sentence_id, text, state)

    def translate_callback(sentence_id, text):
        if translation_enabled:
            try:
                translation = translator.translate(text)
                logger.info(f"[Translation] sentence {sentence_id}: {translation}")
            except Exception as e:
                logger.error(f"Translation failed: {str(e)}")
        else:
            logger.info(f"[Translation disabled] sentence {sentence_id}: {text}")

    aggregator = SemanticAggregator(
        on_display=display_callback,
        on_translate=translate_callback,
        transcriber=transcriber,
        segments_dir="dataset/audio/segments",
        max_window=10.0,
        max_segments=10,
        force_flush_timeout=5.0
    )

    test_file = "dataset/transcripts/test1_segment_1_20250423_201934.json"
    try:
        results = load_transcription_results(test_file)
        logger.info(f"Loaded {len(results)} transcription results")
    except Exception as e:
        logger.error(f"Failed to load transcription results: {str(e)}")
        sys.exit(1)

    for i, result in enumerate(results):
        logger.info(f"Adding transcription result {i+1}/{len(results)}: {result.text}")
        aggregator.add_segment(result)

    # Flush anything still buffered at the end of the stream.
    aggregator.flush(force=True)

    logger.info("Test complete")