tianyaogavin's picture
init main framework
1bf36cc
from faster_whisper import WhisperModel
import json
from typing import List, Optional, Union
from dataclasses import dataclass
import os
from datetime import datetime
import numpy as np
import soundfile as sf
import time
import logging
@dataclass
class TranscriptionResult:
text: str
start_time: float
end_time: float
confidence: float
verified: bool = False
verified_text: Optional[str] = None
verification_notes: Optional[str] = None
segment_index: Optional[int] = None # 添加片段索引字段
# 配置日志
def setup_logger(name, level=logging.INFO):
"""设置日志记录器"""
logger = logging.getLogger(name)
# 清除所有已有的handler,避免重复
if logger.handlers:
logger.handlers.clear()
# 添加新的handler
handler = logging.StreamHandler()
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
handler.setFormatter(formatter)
logger.addHandler(handler)
logger.setLevel(level)
# 禁止传播到父logger,避免重复日志
logger.propagate = False
return logger
# 创建日志记录器
logger = setup_logger("transcribe")
class AudioTranscriber:
def __init__(self, model: str = "medium", device: str = "cuda", compute_type: str = "int8",
log_level: Union[int, str] = logging.INFO):
"""
初始化转录器
Args:
model: Whisper模型名称
device: 使用的设备 ("cpu" 或 "cuda")
compute_type: 计算类型
log_level: 日志级别,可以是logging模块的常量(DEBUG, INFO等)或对应的字符串
"""
# 设置日志级别
if isinstance(log_level, str):
log_level = getattr(logging, log_level.upper())
logger.setLevel(log_level)
logger.debug(f"初始化转录器: model={model}, device={device}, compute_type={compute_type}")
from faster_whisper import WhisperModel
self.model = WhisperModel(model, device=device, compute_type=compute_type)
logger.debug("Whisper模型加载完成")
def transcribe_segment(self, audio_data: np.ndarray, start_time: float = 0.0) -> List[TranscriptionResult]:
"""
转录单个音频片段
Args:
audio_data: numpy数组格式的音频数据
start_time: 该片段在整体音频中的起始时间(秒)
Returns:
转录结果列表
"""
start_process_time = time.time()
logger.debug(f"开始转录音频片段,长度: {len(audio_data)} 采样点 ({len(audio_data)/16000:.2f}秒)")
try:
segments_generator, info = self.model.transcribe(audio_data, beam_size=5)
logger.debug(f"检测到语言: {info.language}, 语言概率: {info.language_probability:.2f}")
segments = list(segments_generator)
logger.debug(f"转录成功,片段数: {len(segments)}")
if len(segments) > 0:
logger.debug(f"第一个片段: {segments[0]}")
results = []
for seg in segments:
# 调整时间戳以匹配原始音频中的位置
absolute_start = start_time + seg.start
absolute_end = start_time + seg.end
result = TranscriptionResult(
text=seg.text,
start_time=absolute_start,
end_time=absolute_end,
confidence=1.0 - seg.no_speech_prob
)
results.append(result)
end_process_time = time.time()
process_duration = end_process_time - start_process_time
# 转录时间信息使用INFO级别,确保默认情况下可见
logger.info(f"转录完成,耗时: {process_duration:.2f}秒,共 {len(results)} 条结果")
return results
except Exception as e:
# 错误信息使用ERROR级别
logger.error(f"转录出错: {str(e)}")
raise
def save_transcription(self,
results: List[TranscriptionResult],
audio_path: str,
output_dir: str = "dataset/transcripts"):
"""
保存转录结果到JSON文件
"""
# 生成输出文件名
base_name = os.path.splitext(os.path.basename(audio_path))[0]
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
output_path = os.path.join(output_dir, f"{base_name}_{timestamp}.json")
# 准备保存的数据
data = {
"audio_file": audio_path,
"timestamp": timestamp,
"segments": [
{
"text": r.text,
"start_time": r.start_time,
"end_time": r.end_time,
"confidence": r.confidence,
"verified": r.verified,
"verified_text": r.verified_text,
"verification_notes": r.verification_notes,
"segment_index": r.segment_index # 添加片段索引到输出
}
for r in results
]
}
# 保存到文件
os.makedirs(output_dir, exist_ok=True)
with open(output_path, 'w', encoding='utf-8') as f:
json.dump(data, f, ensure_ascii=False, indent=2)
logger.info(f"转录结果已保存到: {output_path}")
return output_path
if __name__ == "__main__":
# 测试代码
audio_path = "dataset/audio/test1.wav" # 替换为实际的音频文件路径
import soundfile as sf
# 设置日志级别为DEBUG以查看详细信息
logger.setLevel(logging.DEBUG)
# 初始化转录器
processor = AudioTranscriber(log_level="DEBUG")
try:
audio_data, sample_rate = sf.read(audio_path)
logger.info(f"读取音频文件: {audio_path}, 采样率: {sample_rate}Hz, 长度: {len(audio_data)}采样点")
results = processor.transcribe_segment(audio_data, start_time=0.0)
logger.info(f"转录结果共 {len(results)} 条:")
for res in results:
logger.info(f"[{res.start_time:.2f} - {res.end_time:.2f}] {res.text}")
except Exception as e:
logger.error(f"转录测试出错: {e}")