File size: 15,971 Bytes
1bf36cc
526f2ef
 
1bf36cc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
526f2ef
 
 
 
 
 
1bf36cc
526f2ef
 
 
 
1bf36cc
 
 
 
526f2ef
 
 
 
 
 
1bf36cc
 
 
 
526f2ef
 
 
 
 
 
1bf36cc
 
 
 
526f2ef
 
 
 
 
1bf36cc
 
 
 
526f2ef
 
 
 
 
 
1bf36cc
526f2ef
 
 
1bf36cc
526f2ef
 
 
 
 
 
 
1bf36cc
526f2ef
 
 
 
 
 
 
 
 
1bf36cc
 
 
 
 
 
 
526f2ef
1bf36cc
526f2ef
 
 
 
1bf36cc
526f2ef
1bf36cc
526f2ef
 
1bf36cc
 
 
 
 
526f2ef
1bf36cc
526f2ef
 
1bf36cc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
526f2ef
 
 
 
1bf36cc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
526f2ef
1bf36cc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
526f2ef
 
1bf36cc
526f2ef
 
1bf36cc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
526f2ef
1bf36cc
 
 
 
 
 
 
 
 
 
 
 
 
 
526f2ef
1bf36cc
 
 
 
526f2ef
1bf36cc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
from typing import List, Callable, Optional, Dict, Tuple
import uuid
import time
import os
import numpy as np
import soundfile as sf
import logging
from openai import OpenAI
from transcribe.transcribe import TranscriptionResult, AudioTranscriber

# Logging configuration
def setup_logger(name, level=logging.INFO):
    """Return the logger called *name* with exactly one stream handler attached.

    Handlers already present on the logger are discarded first, so calling this
    repeatedly (e.g. on module reload) never duplicates output.

    :param name: logger name passed to ``logging.getLogger``
    :param level: level applied to the logger (default ``logging.INFO``)
    :return: the configured ``logging.Logger``
    """
    configured = logging.getLogger(name)
    # Start from a clean slate so repeated setup never stacks handlers.
    configured.handlers.clear()

    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(
        logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    )
    configured.addHandler(stream_handler)
    configured.setLevel(level)
    # Do not forward records to ancestor loggers; that would print them twice.
    configured.propagate = False
    return configured


# Module-level logger shared by the classes below
logger = setup_logger("aggregator")

class SentenceCompletionDetector:
    """
    Judge with a ChatGPT model whether a piece of text ends a sentence.

    When the OpenAI client cannot be created (for example because
    ``OPENAI_API_KEY`` is not set), when a request fails, or when the reply is
    empty, a simple length heuristic is used instead. Callers therefore always
    get a boolean.
    """

    # Length heuristic used without the model: text with more characters than
    # this is treated as a complete sentence.
    FALLBACK_MIN_LENGTH = 20

    # Characters stripped from both ends of the model reply before parsing, so
    # replies such as '"True"', 'True.' or 'True。' are still recognised.
    _REPLY_STRIP_CHARS = " \t\r\n\"'“”‘’`.。,，!！:："

    def __init__(self, model="gpt-3.5-turbo"):
        """
        :param model: chat-completion model used for the judgement
        """
        from openai import OpenAIError

        self.model = model
        try:
            self.client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
        except OpenAIError as e:
            # The client raises at construction time when no API key is
            # available; degrade to the heuristic instead of failing the caller.
            logger.warning(f"OpenAI客户端初始化失败,将使用备用规则判断句子完整性: {str(e)}")
            self.client = None

    def build_prompt(self, text: str) -> str:
        """Build the few-shot prompt asking whether *text* ends a sentence."""
        return (
            "判断以下语句是否为一句话的结尾,如果是,返回 True,否则返回 False:\n"
            "\"你会学习到如何使用音频数据集,包括音频数据加载\"\n"
            "False\n\n"
            "\"你会学习到如何使用音频数据集,包括音频数据加载,音频数据预处理,以及高效加载大规模音频数据集的流式加载方法\"\n"
            "True\n\n"
            "\"在开始学习之前,我们需要\"\n"
            "False\n\n"
            "\"在开始学习之前,我们需要了解一些基本概念\"\n"
            "True\n\n"
            "\"第一章,介绍基础知识\"\n"
            "True\n\n"
            f"\"{text}\"\n"
        )

    @classmethod
    def _fallback(cls, text: str) -> bool:
        """Heuristic used without the model: longer text is probably complete."""
        return len(text) > cls.FALLBACK_MIN_LENGTH

    @classmethod
    def _parse_verdict(cls, reply: Optional[str]) -> bool:
        """
        Interpret the model reply.

        Returns True only if the reply starts with "true" (case-insensitive)
        once surrounding whitespace, quotes and punctuation are removed.
        """
        normalized = (reply or "").strip(cls._REPLY_STRIP_CHARS).lower()
        return normalized.startswith("true")

    def is_sentence_complete(self, text: str) -> bool:
        """
        Return whether *text* appears to be a complete sentence.

        :param text: text to judge
        :return: the model's verdict. The length heuristic is used instead when
                 no client is available, the request fails, or the reply is
                 empty.
        """
        if self.client is None:
            return self._fallback(text)

        prompt = self.build_prompt(text)
        try:
            response = self.client.chat.completions.create(
                model=self.model,
                messages=[
                    {"role": "system", "content": "你是一个语言专家,擅长判断句子是否完整。"},
                    {"role": "user", "content": prompt},
                ],
                temperature=0.1,
                max_tokens=10,
            )
            result = response.choices[0].message.content
        except Exception as e:
            # Network/API failures must not break aggregation; use the heuristic.
            logger.error(f"调用ChatGPT出错: {str(e)}")
            return self._fallback(text)

        if result is None:
            return self._fallback(text)
        logger.debug(f"ChatGPT判断结果: {result.strip()}")
        return self._parse_verdict(result)

class SemanticAggregator:
    """
    Semantic aggregation controller.

    - Buffers incoming transcription segments.
    - Decides when the buffered segments form a complete semantic unit.
    - Pushes each unit downstream. ``on_display`` receives the raw text and,
      if different, the re-transcribed "optimized" text. ``on_translate``
      receives the text to translate.
    """

    def __init__(
        self,
        on_display: Callable[[str, str, str], None],
        on_translate: Callable[[str, str], None],
        transcriber: AudioTranscriber,
        segments_dir: str = "dataset/audio/segments",
        max_window: float = 5.0,
        max_segments: int = 5,
        min_gap: float = 0.8,
        force_flush_timeout: float = 3.0,
        segment_filename_template: str = "test1_segment_{index}.wav",
        sample_rate: int = 16000,
    ):
        """
        :param on_display: display callback ``(sentence_id, text, state)``;
                           state is ``"raw"`` or ``"optimized"``
        :param on_translate: translation callback ``(sentence_id, text)``
        :param transcriber: transcriber used to re-transcribe merged audio
        :param segments_dir: directory holding the per-segment audio files
        :param max_window: maximum span in seconds (first start to last end)
                           of one unit
        :param max_segments: maximum number of segments in one unit
        :param min_gap: a pause longer than this (seconds) between consecutive
                        segments closes the current unit
        :param force_flush_timeout: flush once the oldest buffered segment has
                                    waited longer than this many wall-clock
                                    seconds (checked when a segment is added)
        :param segment_filename_template: file name of a segment's audio inside
                                          ``segments_dir``; ``{index}`` is
                                          replaced by the segment index
        :param sample_rate: expected sample rate of the segment files (Hz)
        """
        self.buffer: List[TranscriptionResult] = []
        self.on_display = on_display
        self.on_translate = on_translate
        self.transcriber = transcriber
        self.segments_dir = segments_dir
        self.max_window = max_window
        self.max_segments = max_segments
        self.min_gap = min_gap
        self.force_flush_timeout = force_flush_timeout
        self.segment_filename_template = segment_filename_template
        # Wall-clock time of the most recent flush (kept for callers/inspection).
        self.last_flush_time = time.time()
        # Wall-clock time at which the oldest segment now buffered arrived;
        # None while the buffer is empty. The force-flush timeout is measured
        # from here, so a long idle period does not flush the next segment alone.
        self._buffer_started_at: Optional[float] = None
        self.sentence_detector = SentenceCompletionDetector()
        # Decoded audio per segment index, so each file is read only once.
        # NOTE(review): entries are never evicted; memory grows with the number
        # of distinct segment indices seen.
        self.audio_cache: Dict[int, np.ndarray] = {}
        self.sample_rate = sample_rate
        logger.debug(f"语义聚合器初始化完成,参数: max_window={max_window}, max_segments={max_segments}")

    def add_segment(self, result: TranscriptionResult):
        """
        Add a transcription segment and flush when a unit is complete.

        A pause longer than ``min_gap`` before *result* closes the current unit
        first. The segment after the pause then starts a new unit instead of
        being glued onto the previous one.
        """
        if self.buffer:
            gap = result.start_time - self.buffer[-1].end_time
            if gap > self.min_gap:
                logger.info(f"检测到较大间隔: {gap:.2f}秒")
                self._aggregate_and_flush()

        if not self.buffer:
            self._buffer_started_at = time.time()
        self.buffer.append(result)
        logger.debug(f"添加片段: {result.text}")

        if self._should_aggregate():
            self._aggregate_and_flush()
        elif time.time() - self._buffer_started_at > self.force_flush_timeout:
            logger.debug(f"超时强制刷新: {self.force_flush_timeout}秒")
            self.flush(force=True)

    def flush(self, force: bool = False):
        """
        Emit whatever is buffered, ignoring the aggregation rules.

        :param force: accepted for API compatibility; the flush is always
                      unconditional
        """
        if self.buffer:
            logger.debug(f"强制刷新缓冲区,当前片段数: {len(self.buffer)}")
            self._aggregate_and_flush()
        self.last_flush_time = time.time()

    def _should_aggregate(self) -> bool:
        """
        Decide whether the buffered segments should be emitted now.

        The cheap local limits are checked first, so the ChatGPT request is
        made only when neither limit has been reached. Pauses between segments
        are handled in ``add_segment`` before a segment is buffered.
        """
        if not self.buffer:
            return False

        total_duration = self.buffer[-1].end_time - self.buffer[0].start_time
        if total_duration > self.max_window:
            logger.info(f"达到最大时间窗口: {total_duration:.2f}秒")
            return True
        if len(self.buffer) >= self.max_segments:
            logger.info(f"达到最大片段数: {len(self.buffer)}")
            return True

        # Join with "," to match the text emitted by _aggregate_and_flush.
        combined_text = ",".join(seg.text for seg in self.buffer)
        if self.sentence_detector.is_sentence_complete(combined_text):
            logger.info(f"检测到完整句子: {combined_text}")
            return True

        return False

    def _get_segment_audio(self, segment_index: int) -> np.ndarray:
        """
        Load (and cache) the samples of one segment's audio file.

        Multi-channel files are averaged to one channel so segments concatenate
        into a single 1-D signal. Returns an empty array if the file cannot be
        read; failed reads are not cached.
        """
        cached = self.audio_cache.get(segment_index)
        if cached is not None:
            return cached

        audio_path = os.path.join(
            self.segments_dir,
            self.segment_filename_template.format(index=segment_index),
        )
        try:
            audio_data, sample_rate = sf.read(audio_path)
        except Exception as e:
            logger.error(f"读取音频文件失败: {audio_path}, 错误: {str(e)}")
            return np.array([])

        if audio_data.ndim > 1:
            audio_data = audio_data.mean(axis=1)
        if sample_rate != self.sample_rate:
            # No resampling is done here; the mismatch is only reported.
            logger.warning(f"音频采样率 {sample_rate}Hz 与预期 {self.sample_rate}Hz 不一致: {audio_path}")
        self.audio_cache[segment_index] = audio_data
        logger.debug(f"读取音频片段: {audio_path}, 长度: {len(audio_data)/sample_rate:.2f}秒")
        return audio_data

    def _combine_audio_segments(self, segment_indices: List[int]) -> Tuple[np.ndarray, float]:
        """
        Concatenate the audio of the given segments in the order given.

        :return: (combined samples, start time of the first buffered segment);
                 ``(np.array([]), 0.0)`` when no audio could be loaded
        """
        audio_segments = [
            audio
            for audio in (self._get_segment_audio(idx) for idx in segment_indices)
            if len(audio) > 0
        ]
        if not audio_segments:
            return np.array([]), 0.0

        combined_audio = np.concatenate(audio_segments)
        start_time = self.buffer[0].start_time if self.buffer else 0.0

        logger.debug(f"合并音频片段: {segment_indices}, 总长度: {len(combined_audio)/self.sample_rate:.2f}秒")
        return combined_audio, start_time

    def _retranscribe_segments(self, segment_indices: List[int]) -> List[TranscriptionResult]:
        """
        Re-transcribe the merged audio of the given segments.

        :return: the transcriber's results, or an empty list if there is no
                 audio or transcription fails
        """
        combined_audio, start_time = self._combine_audio_segments(segment_indices)
        if len(combined_audio) == 0:
            logger.warning("没有有效的音频数据可以重新转录")
            return []

        logger.debug(f"重新转录合并的音频片段, 长度: {len(combined_audio)/self.sample_rate:.2f}秒")
        try:
            results = self.transcriber.transcribe_segment(combined_audio, start_time=start_time)
        except Exception as e:
            logger.error(f"重新转录失败: {str(e)}")
            return []
        logger.debug(f"重新转录结果: {len(results)}条")
        return results

    def _collect_segment_indices(self) -> List[int]:
        """Return the sorted, de-duplicated audio segment indices referenced by the buffer."""
        # NOTE(review): assumes each segment file holds only the audio of the
        # results that reference it; if many results share one index, the whole
        # file is re-transcribed for every unit — confirm against the producer.
        indices = set()
        for seg in self.buffer:
            index = getattr(seg, "segment_index", None)
            if index is None:
                continue
            if isinstance(index, list):
                indices.update(index)
            else:
                indices.add(index)
        return sorted(indices)

    def _aggregate_and_flush(self):
        """
        Emit the buffered segments as one unit, then clear the buffer.

        The raw joined text is displayed first. If the segments reference audio
        files, the merged audio is re-transcribed. A differing result is shown
        as ``"optimized"``, and the re-transcribed text is what gets translated.
        Otherwise the raw text is translated. The buffer is cleared even if a
        callback raises, so a failing consumer cannot make the same unit be
        emitted again.
        """
        if not self.buffer:
            return

        try:
            segment_indices = self._collect_segment_indices()
            sentence_id = str(uuid.uuid4())

            # 1. Raw text first: segments joined by ",", no trailing full stop.
            original_text = ",".join(seg.text for seg in self.buffer)
            logger.info(f"原始聚合文本: {original_text}")
            self.on_display(sentence_id, original_text, "raw")

            if not segment_indices:
                logger.warning("没有有效的片段索引,使用原始文本进行翻译")
                self.on_translate(sentence_id, original_text)
                return

            # 2. Re-transcribe the merged audio for a (hopefully) better text.
            retranscribed_results = self._retranscribe_segments(segment_indices)
            if not retranscribed_results:
                logger.warning("重新转录失败,使用原始文本进行翻译")
                self.on_translate(sentence_id, original_text)
                return

            retranscribed_text = ",".join(res.text for res in retranscribed_results)
            logger.info(f"重新转录文本: {retranscribed_text}")
            if retranscribed_text != original_text:
                self.on_display(sentence_id, retranscribed_text, "optimized")
            self.on_translate(sentence_id, retranscribed_text)
        finally:
            buffer_size = len(self.buffer)
            self.buffer.clear()
            self._buffer_started_at = None
            self.last_flush_time = time.time()
            logger.debug(f"清空缓冲区,释放 {buffer_size} 个片段")


def load_transcription_results(json_path):
    """
    Load transcription results from a JSON file.

    The file must contain a top-level ``"segments"`` list. Each entry needs the
    keys ``text``, ``start_time``, ``end_time``, ``confidence``, ``verified``,
    ``verified_text`` and ``verification_notes``. It may also carry
    ``segment_index``, which becomes ``None`` when absent.

    :param json_path: path of the UTF-8 encoded JSON file
    :return: list of ``TranscriptionResult`` objects in file order
    """
    import json

    with open(json_path, 'r', encoding='utf-8') as fp:
        payload = json.load(fp)

    return [
        TranscriptionResult(
            text=entry['text'],
            start_time=entry['start_time'],
            end_time=entry['end_time'],
            confidence=entry['confidence'],
            verified=entry['verified'],
            verified_text=entry['verified_text'],
            verification_notes=entry['verification_notes'],
            segment_index=entry.get('segment_index'),
        )
        for entry in payload['segments']
    ]

def main():
    """
    Manual end-to-end check of SemanticAggregator.

    Replays a saved transcript through the aggregator. Each unit is rendered
    with OutputRenderer and, when the model loads, translated with
    NLLBTranslator.
    """
    import sys

    # Verbose output for the manual test run.
    logger.setLevel(logging.DEBUG)

    # Warn up front when ChatGPT cannot be used for sentence detection.
    if not os.getenv("OPENAI_API_KEY"):
        logger.warning("未设置OPENAI_API_KEY环境变量,句子完整性判断将使用备用方法")

    # Imported lazily: only this demo needs the display and translator modules.
    from display.display import OutputRenderer
    from translator.translator import NLLBTranslator

    renderer = OutputRenderer()

    # Prefer the GPU; fall back to CPU if initialisation fails.
    try:
        transcriber = AudioTranscriber(model="small", device="cuda", compute_type="int8")
        logger.info("使用GPU进行转录")
    except Exception as e:
        logger.warning(f"GPU初始化失败,使用CPU: {str(e)}")
        transcriber = AudioTranscriber(model="small", device="cpu", compute_type="float32")

    # Translation is optional: the demo still runs if the model cannot load.
    translator = None
    try:
        translator = NLLBTranslator()
    except Exception as e:
        logger.warning(f"翻译器初始化失败: {str(e)}")

    def display_callback(sentence_id, text, state):
        renderer.display(sentence_id, text, state)

    def translate_callback(sentence_id, text):
        if translator is None:
            logger.info(f"[翻译已禁用] 句子 {sentence_id}: {text}")
            return
        try:
            translation = translator.translate(text)
            logger.info(f"[翻译] 句子 {sentence_id}: {translation}")
        except Exception as e:
            logger.error(f"翻译失败: {str(e)}")

    aggregator = SemanticAggregator(
        on_display=display_callback,
        on_translate=translate_callback,
        transcriber=transcriber,
        segments_dir="dataset/audio/segments",
        max_window=10.0,  # larger window for testing
        max_segments=10,  # more segments per unit for testing
        force_flush_timeout=5.0,  # longer timeout for testing
    )

    test_file = "dataset/transcripts/test1_segment_1_20250423_201934.json"
    try:
        results = load_transcription_results(test_file)
    except Exception as e:
        logger.error(f"加载转录结果失败: {str(e)}")
        sys.exit(1)
    logger.info(f"加载了 {len(results)} 条转录结果")

    # Feed the results one by one, as a live transcriber would.
    for i, result in enumerate(results, start=1):
        logger.info(f"添加第 {i}/{len(results)} 条转录结果: {result.text}")
        aggregator.add_segment(result)

    # Emit whatever is still buffered at the end of the input.
    aggregator.flush(force=True)

    logger.info("测试完成")


if __name__ == "__main__":
    main()