tianyaogavin committed on
Commit
7708082
·
1 Parent(s): 25666e3

add transcribe

Browse files
Files changed (1) hide show
  1. transcribe/transcribe.py +169 -0
transcribe/transcribe.py ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from faster_whisper import WhisperModel
2
+ import json
3
+ from typing import List, Optional, Union
4
+ from dataclasses import dataclass
5
+ import os
6
+ from datetime import datetime
7
+ import numpy as np
8
+ import soundfile as sf
9
+ import time
10
+ import logging
11
+
12
@dataclass
class TranscriptionResult:
    """One transcribed span of audio, plus optional review metadata.

    Attributes:
        text: Transcribed text of the span.
        start_time: Start of the span, in seconds.
        end_time: End of the span, in seconds.
        confidence: Score attached to the span by the producer of the result.
        verified: Whether the span has been verified; defaults to False.
        verified_text: Corrected text, if any (presumably set during review).
        verification_notes: Free-form notes, if any (presumably set during review).
        segment_index: Index of the segment this result belongs to, if known.
    """

    text: str
    start_time: float
    end_time: float
    confidence: float
    verified: bool = False
    verified_text: Optional[str] = None
    verification_notes: Optional[str] = None
    segment_index: Optional[int] = None  # segment index field
22
+
23
# Configure logging for this module.
logger = logging.getLogger("transcribe")
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
handler = logging.StreamHandler()
handler.setFormatter(formatter)
# logging.getLogger returns the same logger object for the same name, so
# attaching a handler unconditionally would add one more handler (and one more
# copy of every log line) each time this module is imported again or reloaded.
if not logger.handlers:
    logger.addHandler(handler)
29
+
30
class AudioTranscriber:
    """Wrap a faster-whisper ``WhisperModel`` to transcribe in-memory audio
    and persist the results as JSON."""

    # Sample rate (Hz) used to convert a sample count into a duration for logs.
    SAMPLE_RATE = 16000

    def __init__(self, model: str = "medium", device: str = "cuda", compute_type: str = "int8",
                 log_level: Union[int, str] = logging.INFO):
        """
        Initialise the transcriber and load the model.

        Args:
            model: Whisper model name.
            device: Device to run on ("cpu" or "cuda").
            compute_type: Compute type passed to the model.
            log_level: Log level, either a ``logging`` constant (DEBUG, INFO, ...)
                or the corresponding name as a string (case-insensitive).

        Raises:
            AttributeError: If ``log_level`` is a string that is not an
                attribute of the ``logging`` module.
        """
        # Set the log level of the module logger.
        if isinstance(log_level, str):
            log_level = getattr(logging, log_level.upper())
        logger.setLevel(log_level)

        logger.debug("📥 Loading Whisper model...")

        # Imported here so the dependency is only needed once a transcriber
        # is actually constructed.
        from faster_whisper import WhisperModel
        self.model = WhisperModel(model, device=device, compute_type=compute_type)

        logger.debug("📥 Loading Whisper model successfully!!")

    @staticmethod
    def _prepare_audio(audio_data):
        """Return ``audio_data`` as a contiguous 1-D float32 array.

        Non-ndarray inputs are returned untouched so whatever the caller passed
        still reaches the model unchanged. A 2-D array is downmixed to mono by
        averaging over its smaller axis, which is taken to be the channel axis.
        """
        if not isinstance(audio_data, np.ndarray):
            return audio_data
        if audio_data.size == 0:
            return np.zeros(0, dtype=np.float32)
        audio = audio_data
        if audio.ndim == 2:
            # Handles both (frames, channels) and (channels, frames) layouts.
            audio = audio.mean(axis=int(np.argmin(audio.shape)))
        return np.ascontiguousarray(audio, dtype=np.float32)

    def transcribe_segment(self, audio_data: np.ndarray, start_time: float = 0.0) -> List["TranscriptionResult"]:
        """
        Transcribe a single audio segment.

        Args:
            audio_data: Audio samples as a numpy array. Durations are reported
                assuming ``SAMPLE_RATE`` Hz; no resampling is performed.
            start_time: Offset (seconds) of this segment within the full audio;
                added to every returned timestamp.

        Returns:
            List of transcription results; empty when the input has no samples.

        Raises:
            Exception: Any error raised by the model is logged and re-raised.
        """
        started = time.perf_counter()  # monotonic clock for elapsed time

        audio = self._prepare_audio(audio_data)

        logger.debug("Model transcribe...")
        logger.debug(f"开始转录音频片段，长度: {len(audio)} 采样点 ({len(audio)/self.SAMPLE_RATE:.2f}秒)")

        if len(audio) == 0:
            logger.warning("Empty audio segment, nothing to transcribe")
            return []

        try:
            segments_generator, info = self.model.transcribe(audio, beam_size=5)

            logger.debug(f"检测到语言: {info.language}, 语言概率: {info.language_probability:.2f}")

            # Materialise the generator so the segments can be counted and reused.
            segments = list(segments_generator)

            logger.debug(f"Model transcribe successfully! Segments count: {len(segments)}")
            if segments:
                logger.debug(f"First segment: {segments[0]}")

            # Shift timestamps so they refer to positions in the original audio.
            results = [
                TranscriptionResult(
                    text=seg.text,
                    start_time=start_time + seg.start,
                    end_time=start_time + seg.end,
                    confidence=1.0 - seg.no_speech_prob,
                )
                for seg in segments
            ]

            process_duration = time.perf_counter() - started

            # Timing is logged at INFO so it is visible with the default level.
            logger.info(f"转录完成，耗时: {process_duration:.2f}秒，共 {len(results)} 条结果")

            return results

        except Exception as e:
            # Failures are logged at ERROR and then propagated to the caller.
            logger.error(f"转录出错: {str(e)}")
            raise

    def save_transcription(self,
                           results: List["TranscriptionResult"],
                           audio_path: str,
                           output_dir: str = "dataset/transcripts") -> str:
        """
        Save transcription results to a JSON file.

        The file is named ``<audio base name>_<YYYYmmdd_HHMMSS>.json`` and is
        written inside ``output_dir``, which is created if missing.

        Args:
            results: Results to serialise.
            audio_path: Path of the source audio; stored in the file and used
                to derive the output file name.
            output_dir: Directory to write into. An empty string means the
                current working directory.

        Returns:
            Path of the written JSON file.
        """
        # Build the output file name.
        stem, _ = os.path.splitext(os.path.basename(audio_path))
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        output_path = os.path.join(output_dir, f"{stem}_{timestamp}.json")

        # Assemble the data to save.
        data = {
            "audio_file": audio_path,
            "timestamp": timestamp,
            "segments": [
                {
                    "text": r.text,
                    "start_time": r.start_time,
                    "end_time": r.end_time,
                    "confidence": r.confidence,
                    "verified": r.verified,
                    "verified_text": r.verified_text,
                    "verification_notes": r.verification_notes,
                    "segment_index": r.segment_index,
                }
                for r in results
            ],
        }

        # os.makedirs("") raises, so only create a directory when one is named.
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(data, f, ensure_ascii=False, indent=2)

        return output_path
145
+
146
+
147
if __name__ == "__main__":
    # Manual smoke test: transcribe one file and log the result.
    audio_path = "dataset/audio/test1.wav"  # replace with a real audio file path
    import soundfile as sf

    # Log level: DEBUG, INFO, WARNING, ERROR, CRITICAL.
    # It can be given either as a string or as a logging constant.
    processor = AudioTranscriber(log_level="DEBUG")  # or log_level=logging.INFO

    # The logger level can also be set directly:
    # logger.setLevel(logging.DEBUG)  # show all detailed logs

    try:
        # Read as float32 instead of soundfile's default float64.
        audio_data, sample_rate = sf.read(audio_path, dtype="float32")
        logger.info(f"读取音频文件: {audio_path}, 采样率: {sample_rate}Hz, 长度: {len(audio_data)}采样点")

        # soundfile returns (frames, channels) for multi-channel files;
        # average the channels to get a mono waveform.
        if audio_data.ndim > 1:
            audio_data = audio_data.mean(axis=1)

        # transcribe_segment converts sample counts to seconds with a fixed
        # 16000 Hz rate and nothing here resamples, so flag a mismatch.
        if sample_rate != 16000:
            logger.warning(f"Sample rate is {sample_rate}Hz, expected 16000Hz; "
                           "timestamps and transcription quality may be affected")

        results = processor.transcribe_segment(audio_data, start_time=0.0)

        logger.info(f"转录结果共 {len(results)} 条:")
        for res in results:
            logger.info(f"[{res.start_time:.2f} - {res.end_time:.2f}] {res.text}")
    except Exception as e:
        # Same message as before, now with the traceback attached.
        logger.exception(f"转录测试出错: {e}")