File size: 6,632 Bytes
7708082 1bf36cc 7708082 1bf36cc 7708082 1bf36cc 7708082 1bf36cc 7708082 1bf36cc 7708082 1bf36cc 7708082 1bf36cc 7708082 1bf36cc 7708082 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 |
from faster_whisper import WhisperModel
import json
from typing import List, Optional, Union
from dataclasses import dataclass
import os
from datetime import datetime
import numpy as np
import soundfile as sf
import time
import logging
@dataclass
class TranscriptionResult:
    """One transcribed span of audio plus optional verification metadata.

    Fields are listed in constructor order; only the first four are required.
    """

    # Recognized text of this span.
    text: str
    # Span boundaries in seconds; AudioTranscriber.transcribe_segment makes these
    # absolute by adding the chunk's offset within the full recording.
    start_time: float
    end_time: float
    # Producer-defined score (AudioTranscriber stores 1 - no_speech_prob here).
    confidence: float
    # Verification fields: nothing in this module sets them; presumably filled in
    # by a later review step.
    verified: bool = False
    verified_text: Optional[str] = None
    verification_notes: Optional[str] = None
    # Optional index of the segment/chunk; not set by transcribe_segment itself.
    segment_index: Optional[int] = None
# Logging configuration
def setup_logger(name, level=logging.INFO):
    """Create (or reset) a named logger with a single stream handler.

    Any handlers already attached to the logger are detached and closed first,
    so calling this repeatedly for the same name never duplicates output or
    leaks previously attached handlers.

    Args:
        name: Logger name passed to ``logging.getLogger``.
        level: Level for the logger (a ``logging`` constant or a level name
            accepted by ``Logger.setLevel``).

    Returns:
        The configured ``logging.Logger``.
    """
    logger = logging.getLogger(name)
    # Detach AND close old handlers; clearing the list alone would leave them
    # open (e.g. a FileHandler's file) and still registered with logging.
    for old_handler in list(logger.handlers):
        logger.removeHandler(old_handler)
        old_handler.close()
    handler = logging.StreamHandler()
    handler.setFormatter(
        logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    )
    logger.addHandler(handler)
    logger.setLevel(level)
    # Don't propagate to ancestor loggers, which would print each record twice.
    logger.propagate = False
    return logger
# Module-wide logger used by AudioTranscriber and the __main__ smoke test.
logger = setup_logger(name="transcribe")
class AudioTranscriber:
    """Wrapper around a faster-whisper ``WhisperModel`` for transcribing audio.

    Produces ``TranscriptionResult`` objects whose timestamps are relative to
    the full recording, and can persist them as JSON.
    """

    # Sample rate (Hz) the raw-waveform input to faster-whisper is expected to have.
    SAMPLE_RATE = 16000

    # TranscriptionResult attributes written for every segment by save_transcription.
    _SEGMENT_FIELDS = (
        "text",
        "start_time",
        "end_time",
        "confidence",
        "verified",
        "verified_text",
        "verification_notes",
        "segment_index",
    )

    def __init__(self, model: str = "medium", device: str = "cuda", compute_type: str = "int8",
                 log_level: Union[int, str] = logging.INFO):
        """Load the Whisper model.

        Args:
            model: Whisper model name (or path) passed to ``WhisperModel``.
            device: Device to run on ("cpu" or "cuda").
            compute_type: Compute type passed to ``WhisperModel``.
            log_level: Logging level, either a ``logging`` constant (DEBUG,
                INFO, ...) or its case-insensitive name.

        Note:
            The level is applied to the module-level logger, so it affects every
            instance in the process, not only this one.
        """
        if isinstance(log_level, str):
            log_level = getattr(logging, log_level.upper())
        logger.setLevel(log_level)
        logger.debug(f"初始化转录器: model={model}, device={device}, compute_type={compute_type}")
        self.model = WhisperModel(model, device=device, compute_type=compute_type)
        logger.debug("Whisper模型加载完成")

    def transcribe_segment(self, audio_data: np.ndarray, start_time: float = 0.0,
                           beam_size: int = 5) -> List[TranscriptionResult]:
        """Transcribe one chunk of audio.

        Args:
            audio_data: Waveform sampled at ``SAMPLE_RATE``. Either 1-D mono or
                2-D ``(frames, channels)`` (the layout ``soundfile.read`` returns);
                multi-channel input is averaged down to mono. Any numeric dtype is
                accepted and converted to float32.
            start_time: Offset in seconds of this chunk within the full
                recording; added to every segment timestamp.
            beam_size: Beam width passed to ``WhisperModel.transcribe``.

        Returns:
            One ``TranscriptionResult`` per segment produced by the model.
            ``confidence`` is ``1 - no_speech_prob``, i.e. how likely the span
            contains speech, not a measure of text accuracy.

        Raises:
            Any exception raised by the model is logged and re-raised.
        """
        # faster-whisper's own audio decoder yields mono float32; match that format.
        audio = np.asarray(audio_data, dtype=np.float32)
        if audio.ndim == 2:
            audio = audio.mean(axis=1)

        started = time.perf_counter()
        logger.debug(f"开始转录音频片段,长度: {len(audio)} 采样点 ({len(audio)/self.SAMPLE_RATE:.2f}秒)")
        try:
            segments_generator, info = self.model.transcribe(audio, beam_size=beam_size)
            logger.debug(f"检测到语言: {info.language}, 语言概率: {info.language_probability:.2f}")
            # The segments come back lazily and decoding runs while they are
            # consumed, so materialize them inside the try to log decode errors.
            segments = list(segments_generator)
        except Exception as e:
            logger.error(f"转录出错: {str(e)}")
            raise

        logger.debug(f"转录成功,片段数: {len(segments)}")
        if segments:
            logger.debug(f"第一个片段: {segments[0]}")

        # Shift timestamps from chunk-relative to recording-absolute.
        results = [
            TranscriptionResult(
                text=seg.text,
                start_time=start_time + seg.start,
                end_time=start_time + seg.end,
                confidence=1.0 - seg.no_speech_prob,
            )
            for seg in segments
        ]

        elapsed = time.perf_counter() - started
        # Timing is logged at INFO so it is visible with the default level.
        logger.info(f"转录完成,耗时: {elapsed:.2f}秒,共 {len(results)} 条结果")
        return results

    def save_transcription(self,
                           results: List[TranscriptionResult],
                           audio_path: str,
                           output_dir: str = "dataset/transcripts"):
        """Write transcription results to a new JSON file and return its path.

        The file is named ``<audio basename>_<YYYYmmdd_HHMMSS>.json``. If that
        name already exists (e.g. two saves within the same second), ``_1``,
        ``_2``, ... is appended instead of overwriting the earlier file.

        Args:
            results: Segments to save; each listed attribute is written per segment.
            audio_path: Source audio path; stored in the JSON and used to build
                the file name.
            output_dir: Directory to write into; created if missing.

        Returns:
            Path of the written JSON file.
        """
        base_name = os.path.splitext(os.path.basename(audio_path))[0]
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        data = {
            "audio_file": audio_path,
            "timestamp": timestamp,
            "segments": [
                {field: getattr(r, field) for field in self._SEGMENT_FIELDS}
                for r in results
            ],
        }

        os.makedirs(output_dir, exist_ok=True)
        stem = f"{base_name}_{timestamp}"
        attempt = 0
        while True:
            suffix = f"_{attempt}" if attempt else ""
            output_path = os.path.join(output_dir, f"{stem}{suffix}.json")
            try:
                # Exclusive-create mode fails if the file exists, so an earlier
                # transcript is never silently overwritten.
                f = open(output_path, 'x', encoding='utf-8')
            except FileExistsError:
                attempt += 1
                continue
            with f:
                json.dump(data, f, ensure_ascii=False, indent=2)
            break

        logger.info(f"转录结果已保存到: {output_path}")
        return output_path
if __name__ == "__main__":
    # Manual smoke test: transcribe one local file and log every segment.
    audio_path = "dataset/audio/test1.wav"  # replace with a real audio file path

    # log_level="DEBUG" also switches the module logger to DEBUG for verbose output.
    processor = AudioTranscriber(log_level="DEBUG")
    try:
        # Read as float32, the sample format faster-whisper's own decoder produces.
        audio_data, sample_rate = sf.read(audio_path, dtype="float32")
        logger.info(f"读取音频文件: {audio_path}, 采样率: {sample_rate}Hz, 长度: {len(audio_data)}采样点")
        # transcribe_segment assumes 16 kHz input; other rates would give wrong
        # timestamps and poor recognition, so refuse rather than transcribe silently.
        if sample_rate != 16000:
            raise ValueError(f"音频采样率必须为 16000Hz,实际为 {sample_rate}Hz,请先重采样")
        results = processor.transcribe_segment(audio_data, start_time=0.0)
        logger.info(f"转录结果共 {len(results)} 条:")
        for res in results:
            logger.info(f"[{res.start_time:.2f} - {res.end_time:.2f}] {res.text}")
    except Exception as e:
        # logger.exception includes the traceback, which a smoke test needs.
        logger.exception(f"转录测试出错: {e}")
|