# Audio transcription helpers built on faster-whisper.
from faster_whisper import WhisperModel
import json
from typing import List, Optional, Union
from dataclasses import dataclass
import os
from datetime import datetime
import numpy as np
import soundfile as sf
import time
import logging

@dataclass
class TranscriptionResult:
    """One transcribed segment with its timing, confidence and review state.

    Instances are created by ``AudioTranscriber.transcribe_segment`` and written
    out field by field by ``AudioTranscriber.save_transcription``.
    """
    # Text of the segment as returned by the model.
    text: str
    # Segment start in seconds; transcribe_segment adds its ``start_time`` offset
    # to the model's segment-relative start.
    start_time: float
    # Segment end in seconds, offset the same way as ``start_time``.
    end_time: float
    # transcribe_segment fills this with ``1.0 - no_speech_prob`` of the segment.
    confidence: float
    # Review state. Nothing in this file sets these three fields beyond their
    # defaults; presumably a later verification step fills them in — confirm.
    verified: bool = False
    verified_text: Optional[str] = None
    verification_notes: Optional[str] = None
    segment_index: Optional[int] = None  # segment index field (not set by transcribe_segment)

# Logging configuration
def setup_logger(name, level=logging.INFO):
    """Return the logger called ``name`` with exactly one stream handler attached.

    Any handlers already registered on that logger are dropped first, so calling
    this repeatedly never produces duplicated output lines.

    Args:
        name: Name passed to ``logging.getLogger``.
        level: Level applied to the logger (defaults to ``logging.INFO``).

    Returns:
        The configured ``logging.Logger``.
    """
    log = logging.getLogger(name)

    # Level and propagation are independent of the handler setup below.
    log.setLevel(level)
    # Keep records away from ancestor loggers so nothing is emitted twice.
    log.propagate = False

    # Start from a clean handler list (clearing an empty list is a no-op).
    log.handlers.clear()

    stream = logging.StreamHandler()
    stream.setFormatter(
        logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    )
    log.addHandler(stream)
    return log

# Create the module-level logger used by everything below
logger = setup_logger("transcribe")

class AudioTranscriber:
    """Transcribe in-memory audio with faster-whisper and save the results as JSON."""

    def __init__(self, model: str = "medium", device: str = "cuda", compute_type: str = "int8",
                 log_level: Union[int, str] = logging.INFO) -> None:
        """
        Load the Whisper model and set the module logger's level.

        Args:
            model: Whisper model name.
            device: Device to run on ("cpu" or "cuda").
            compute_type: Compute type handed to ``WhisperModel``.
            log_level: Logging level, either a ``logging`` constant (DEBUG, INFO, ...)
                or the matching name as a string (case-insensitive).

        Raises:
            AttributeError: If ``log_level`` is a string that names no attribute
                of the ``logging`` module.
        """
        # Resolve a level name such as "debug" to the numeric logging constant.
        if isinstance(log_level, str):
            log_level = getattr(logging, log_level.upper())
        # NOTE: this changes the level of the shared module logger, not a per-instance one.
        logger.setLevel(log_level)

        logger.debug(f"初始化转录器: model={model}, device={device}, compute_type={compute_type}")

        from faster_whisper import WhisperModel
        self.model = WhisperModel(model, device=device, compute_type=compute_type)

        logger.debug("Whisper模型加载完成")

    def transcribe_segment(self, audio_data: np.ndarray, start_time: float = 0.0) -> List[TranscriptionResult]:
        """
        Transcribe a single audio segment.

        Args:
            audio_data: Audio samples as a numpy array. The duration reported in
                the debug log assumes 16000 samples per second.
            start_time: Offset of this segment within the full recording, in
                seconds; it is added to every returned timestamp.

        Returns:
            One ``TranscriptionResult`` per segment produced by the model, with
            ``confidence`` set to ``1.0 - no_speech_prob``.

        Raises:
            Exception: Any error raised during transcription is logged at ERROR
                level and re-raised unchanged.
        """
        # perf_counter is monotonic, so the reported duration cannot be skewed
        # (or go negative) when the system clock is adjusted mid-run.
        started = time.perf_counter()

        # Hand the model float32 samples regardless of the caller's dtype (e.g.
        # soundfile.read yields float64 by default). Already-float32 input is
        # passed through without a copy; no amplitude rescaling is performed.
        audio_data = np.asarray(audio_data, dtype=np.float32)

        logger.debug(f"开始转录音频片段,长度: {len(audio_data)} 采样点 ({len(audio_data)/16000:.2f}秒)")

        try:
            segments_generator, info = self.model.transcribe(audio_data, beam_size=5)

            logger.debug(f"检测到语言: {info.language}, 语言概率: {info.language_probability:.2f}")

            # The generator is consumed here, inside the try, so errors raised
            # while iterating are logged by the handler below.
            segments = list(segments_generator)

            logger.debug(f"转录成功,片段数: {len(segments)}")
            if segments:
                logger.debug(f"第一个片段: {segments[0]}")

            # Shift segment-relative timestamps to positions in the original audio.
            results = [
                TranscriptionResult(
                    text=seg.text,
                    start_time=start_time + seg.start,
                    end_time=start_time + seg.end,
                    confidence=1.0 - seg.no_speech_prob,
                )
                for seg in segments
            ]

            process_duration = time.perf_counter() - started

            # Timing is logged at INFO so it is visible with the default level.
            logger.info(f"转录完成,耗时: {process_duration:.2f}秒,共 {len(results)} 条结果")

            return results

        except Exception as e:
            # Report at ERROR level, then let the caller decide how to recover.
            logger.error(f"转录出错: {e}")
            raise

    def save_transcription(self,
                         results: List[TranscriptionResult],
                         audio_path: str,
                         output_dir: str = "dataset/transcripts") -> str:
        """
        Save transcription results to a JSON file.

        The file is named ``<audio basename>_<YYYYmmdd_HHMMSS>.json`` and is
        written as UTF-8 with non-ASCII characters left unescaped.

        Args:
            results: Transcription results to serialize.
            audio_path: Path of the source audio; stored in the output and used
                to derive the file name.
            output_dir: Directory to write into; created if it does not exist.

        Returns:
            Path of the JSON file that was written.
        """
        # Build the output file name from the audio's base name and the current time.
        base_name = os.path.splitext(os.path.basename(audio_path))[0]
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        output_path = os.path.join(output_dir, f"{base_name}_{timestamp}.json")

        # Payload: one dict per result, keys in dataclass field order.
        data = {
            "audio_file": audio_path,
            "timestamp": timestamp,
            "segments": [
                {
                    "text": r.text,
                    "start_time": r.start_time,
                    "end_time": r.end_time,
                    "confidence": r.confidence,
                    "verified": r.verified,
                    "verified_text": r.verified_text,
                    "verification_notes": r.verification_notes,
                    "segment_index": r.segment_index,
                }
                for r in results
            ],
        }

        # Write the file, creating the target directory on first use.
        os.makedirs(output_dir, exist_ok=True)
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(data, f, ensure_ascii=False, indent=2)

        logger.info(f"转录结果已保存到: {output_path}")
        return output_path


if __name__ == "__main__":
    # Manual smoke test: read one audio file, transcribe it and log each segment.
    audio_path = "dataset/audio/test1.wav"  # replace with the path of a real audio file
    import soundfile as sf

    # Use DEBUG so the detailed transcription logs are shown.
    logger.setLevel(logging.DEBUG)

    # Create the transcriber.
    processor = AudioTranscriber(log_level="DEBUG")

    try:
        # Read as float32 rather than soundfile's float64 default, so the model
        # is given single-precision samples.
        audio_data, sample_rate = sf.read(audio_path, dtype="float32")
        logger.info(f"读取音频文件: {audio_path}, 采样率: {sample_rate}Hz, 长度: {len(audio_data)}采样点")

        # soundfile returns (frames, channels) for multi-channel files; the
        # transcriber is given a single mono waveform, so average the channels.
        if audio_data.ndim > 1:
            audio_data = audio_data.mean(axis=1)

        # transcribe_segment computes durations assuming 16000 samples per
        # second; flag files recorded at any other rate instead of silently
        # transcribing them with skewed timing.
        if sample_rate != 16000:
            logger.warning(f"采样率为 {sample_rate}Hz, 转录假定输入为 16000Hz, 结果可能不准确")

        results = processor.transcribe_segment(audio_data, start_time=0.0)

        logger.info(f"转录结果共 {len(results)} 条:")
        for res in results:
            logger.info(f"[{res.start_time:.2f} - {res.end_time:.2f}] {res.text}")
    except Exception as e:
        # Top-level boundary of the test script: log with traceback, do not re-raise.
        logger.exception(f"转录测试出错: {e}")