Commit
·
7708082
1
Parent(s):
25666e3
add transcribe
Browse files- transcribe/transcribe.py +169 -0
transcribe/transcribe.py
ADDED
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from faster_whisper import WhisperModel
|
2 |
+
import json
|
3 |
+
from typing import List, Optional, Union
|
4 |
+
from dataclasses import dataclass
|
5 |
+
import os
|
6 |
+
from datetime import datetime
|
7 |
+
import numpy as np
|
8 |
+
import soundfile as sf
|
9 |
+
import time
|
10 |
+
import logging
|
11 |
+
|
12 |
+
@dataclass
class TranscriptionResult:
    """One transcribed segment of speech plus optional verification data.

    The field order below is also the positional constructor order.
    """

    text: str  # transcribed text of the segment
    start_time: float  # segment start, in seconds
    end_time: float  # segment end, in seconds
    confidence: float  # score supplied by whoever builds the result
    verified: bool = False  # results start out unverified
    verified_text: Optional[str] = None  # presumably the corrected text once verified -- confirm with callers
    verification_notes: Optional[str] = None  # free-form notes about the verification
    segment_index: Optional[int] = None  # segment index; None until a caller assigns one
|
22 |
+
|
23 |
+
# Configure the module logger: records are written through a StreamHandler
# and formatted as "time - logger name - level - message".
logger = logging.getLogger("transcribe")
handler = logging.StreamHandler()
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
handler.setFormatter(formatter)
# logging.getLogger returns the same object for the same name, so attach the
# handler only when the logger has none yet; otherwise re-executing this
# module (e.g. importlib.reload) would make every message print repeatedly.
if not logger.handlers:
    logger.addHandler(handler)
|
29 |
+
|
30 |
+
class AudioTranscriber:
    """Transcribe audio with a faster-whisper model and save results as JSON."""

    # Sample rate used only to report a clip's duration in the debug log.
    SAMPLE_RATE = 16000

    def __init__(self, model: str = "medium", device: str = "cuda", compute_type: str = "int8",
                 log_level: Union[int, str] = logging.INFO):
        """Load the Whisper model and set the module logger's level.

        Args:
            model: Whisper model name.
            device: Device to run on ("cpu" or "cuda").
            compute_type: Compute type passed through to ``WhisperModel``.
            log_level: Logging level, either a ``logging`` constant (DEBUG,
                INFO, ...) or its name as a string (case-insensitive).

        Raises:
            AttributeError: If ``log_level`` is a string that is not the name
                of an attribute of the ``logging`` module.
        """
        # The level is applied to the shared module-level logger, so it
        # affects every AudioTranscriber in the process, not just this one.
        if isinstance(log_level, str):
            log_level = getattr(logging, log_level.upper())
        logger.setLevel(log_level)

        logger.debug("📥 Loading Whisper model...")

        from faster_whisper import WhisperModel
        self.model = WhisperModel(model, device=device, compute_type=compute_type)

        logger.debug("📥 Loading Whisper model successfully!!")

    def transcribe_segment(self, audio_data: np.ndarray, start_time: float = 0.0) -> List["TranscriptionResult"]:
        """Transcribe a single audio clip.

        Args:
            audio_data: Audio samples as a numpy array. Floating-point arrays
                that are not float32 are converted to float32 before being
                handed to the model. NOTE(review): the duration in the debug
                log assumes 16 kHz samples -- confirm callers resample first.
            start_time: Offset of this clip within the full recording, in
                seconds; it is added to every segment timestamp.

        Returns:
            One ``TranscriptionResult`` per segment produced by the model, or
            an empty list when ``audio_data`` is an empty array.

        Raises:
            Exception: Any error raised while transcribing is logged at ERROR
                level and re-raised unchanged.
        """
        start_process_time = time.time()

        if isinstance(audio_data, np.ndarray):
            if audio_data.size == 0:
                # Nothing to decode; skip the model call entirely.
                logger.debug("Empty audio segment, skipping transcription")
                return []
            if np.issubdtype(audio_data.dtype, np.floating) and audio_data.dtype != np.float32:
                # soundfile and plain numpy math default to float64, which
                # the model backend does not accept; float32 is expected.
                audio_data = audio_data.astype(np.float32)

        logger.debug("Model transcribe...")
        logger.debug(f"开始转录音频片段,长度: {len(audio_data)} 采样点 ({len(audio_data)/self.SAMPLE_RATE:.2f}秒)")

        try:
            segments_generator, info = self.model.transcribe(audio_data, beam_size=5)

            logger.debug(f"检测到语言: {info.language}, 语言概率: {info.language_probability:.2f}")

            # The model yields segments lazily; decoding actually happens
            # while this list is being built.
            segments = list(segments_generator)

            logger.debug(f"Model transcribe successfully! Segments count: {len(segments)}")
            if segments:
                logger.debug(f"First segment: {segments[0]}")

            # Shift segment timestamps so they are relative to the full
            # recording rather than to this clip.
            results = [
                TranscriptionResult(
                    text=seg.text,
                    start_time=start_time + seg.start,
                    end_time=start_time + seg.end,
                    confidence=1.0 - seg.no_speech_prob,
                )
                for seg in segments
            ]

            process_duration = time.time() - start_process_time

            # Timing is logged at INFO so it is visible at the default level.
            logger.info(f"转录完成,耗时: {process_duration:.2f}秒,共 {len(results)} 条结果")

            return results

        except Exception as e:
            # Record the failure at ERROR level, then let the caller handle it.
            logger.error(f"转录出错: {str(e)}")
            raise

    def save_transcription(self,
                           results: List["TranscriptionResult"],
                           audio_path: str,
                           output_dir: str = "dataset/transcripts"):
        """Write transcription results to a timestamped JSON file.

        The file is named ``<audio base name>_<YYYYmmdd_HHMMSS>.json`` and is
        written as UTF-8 with non-ASCII text left unescaped.

        Args:
            results: Results to serialise, one JSON object per item.
            audio_path: Path of the source audio; stored in the JSON and used
                to derive the output file name.
            output_dir: Directory to write into; created if missing. An empty
                string means the current working directory.

        Returns:
            Path of the JSON file that was written.
        """
        # Build the output file name from the audio name and the current time.
        base_name = os.path.splitext(os.path.basename(audio_path))[0]
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        output_path = os.path.join(output_dir, f"{base_name}_{timestamp}.json")

        # Assemble the JSON document.
        data = {
            "audio_file": audio_path,
            "timestamp": timestamp,
            "segments": [
                {
                    "text": r.text,
                    "start_time": r.start_time,
                    "end_time": r.end_time,
                    "confidence": r.confidence,
                    "verified": r.verified,
                    "verified_text": r.verified_text,
                    "verification_notes": r.verification_notes,
                    "segment_index": r.segment_index,
                }
                for r in results
            ],
        }

        # os.makedirs("") raises FileNotFoundError, so only create the
        # directory when one was actually given.
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(data, f, ensure_ascii=False, indent=2)

        return output_path
|
145 |
+
|
146 |
+
|
147 |
+
if __name__ == "__main__":
    # Manual smoke test: transcribe one file and log every segment.
    audio_path = "dataset/audio/test1.wav"  # replace with a real audio file path
    import soundfile as sf

    # Log level may be DEBUG, INFO, WARNING, ERROR or CRITICAL, given either
    # as a string or as a logging constant.
    processor = AudioTranscriber(log_level="DEBUG")  # or log_level=logging.INFO

    # The level can also be set directly on the logger:
    # logger.setLevel(logging.DEBUG)  # show all detailed logs

    try:
        # soundfile returns float64 by default; request float32, which is
        # what the Whisper model expects for a raw waveform.
        audio_data, sample_rate = sf.read(audio_path, dtype="float32")
        logger.info(f"读取音频文件: {audio_path}, 采样率: {sample_rate}Hz, 长度: {len(audio_data)}采样点")

        if audio_data.ndim > 1:
            # Multi-channel files come back as (frames, channels); average
            # the channels so the model receives a mono 1-D waveform.
            audio_data = audio_data.mean(axis=1)
        if sample_rate != 16000:
            # No resampling is done here, so timestamps and recognition
            # quality will be off for other rates.
            logger.warning(f"Sample rate is {sample_rate}Hz but the model expects 16000Hz; resample the audio first")

        results = processor.transcribe_segment(audio_data, start_time=0.0)

        logger.info(f"转录结果共 {len(results)} 条:")
        for res in results:
            logger.info(f"[{res.start_time:.2f} - {res.end_time:.2f}] {res.text}")
    except Exception as e:
        # Top-level boundary of the script: log with traceback, do not crash.
        logger.exception(f"转录测试出错: {e}")
|