"""Speech-to-text helpers built on faster-whisper.

Provides ``AudioTranscriber`` for transcribing in-memory audio segments and
saving the results as JSON, plus a small manual test entry point.
"""

import json
import logging
import os
import time
from dataclasses import dataclass
from datetime import datetime
from typing import List, Optional, Union

import numpy as np
import soundfile as sf
from faster_whisper import WhisperModel

# Sample rate (Hz) assumed for raw waveforms handed to the model; it is used
# to report segment durations and to sanity-check audio read from disk.
# NOTE(review): faster-whisper's default decoding rate is 16 kHz - confirm if
# a different feature extractor configuration is ever used.
WHISPER_SAMPLE_RATE = 16000


@dataclass
class TranscriptionResult:
    """One transcribed span of audio plus optional manual-verification data."""

    text: str
    start_time: float
    end_time: float
    confidence: float
    verified: bool = False
    verified_text: Optional[str] = None
    verification_notes: Optional[str] = None
    segment_index: Optional[int] = None  # index of the source audio segment


def setup_logger(name, level=logging.INFO):
    """Create (or reset) a stream logger with a timestamped format.

    Args:
        name: Logger name passed to ``logging.getLogger``.
        level: Logging level applied to the logger.

    Returns:
        The configured ``logging.Logger``.
    """
    logger = logging.getLogger(name)

    # Drop handlers left over from a previous call so that messages are not
    # emitted more than once.
    if logger.handlers:
        logger.handlers.clear()

    handler = logging.StreamHandler()
    handler.setFormatter(
        logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    )
    logger.addHandler(handler)
    logger.setLevel(level)

    # Do not propagate to the parent logger; it would print each record again.
    logger.propagate = False

    return logger


# Module-level logger shared by everything in this file.
logger = setup_logger("transcribe")


class AudioTranscriber:
    """Thin wrapper around a faster-whisper ``WhisperModel``."""

    def __init__(self, model: str = "medium", device: str = "cuda",
                 compute_type: str = "int8",
                 log_level: Union[int, str] = logging.INFO):
        """Load the Whisper model and configure the module logger level.

        Args:
            model: Whisper model name.
            device: Device to run on ("cpu" or "cuda").
            compute_type: Compute type passed to ``WhisperModel``.
            log_level: Logging level, either a ``logging`` constant
                (DEBUG, INFO, ...) or its name as a string.
        """
        # Accept level names such as "debug" by resolving them on the
        # logging module.
        if isinstance(log_level, str):
            log_level = getattr(logging, log_level.upper())

        logger.setLevel(log_level)

        logger.debug(f"初始化转录器: model={model}, device={device}, compute_type={compute_type}")
        self.model = WhisperModel(model, device=device, compute_type=compute_type)
        logger.debug("Whisper模型加载完成")

    def transcribe_segment(self, audio_data: np.ndarray,
                           start_time: float = 0.0) -> List[TranscriptionResult]:
        """Transcribe a single in-memory audio segment.

        Args:
            audio_data: Audio samples as a numpy array. Assumed to be a mono
                waveform sampled at ``WHISPER_SAMPLE_RATE``.
            start_time: Offset (seconds) of this segment within the full
                recording; added to every returned timestamp.

        Returns:
            A list of ``TranscriptionResult``, one per model segment, with
            timestamps relative to the full recording. Empty input yields an
            empty list.

        Raises:
            Exception: Any error raised by the model is logged and re-raised.
        """
        # Nothing to transcribe; avoid handing a zero-length buffer to the model.
        if len(audio_data) == 0:
            logger.debug("音频片段为空,跳过转录")
            return []

        start_process_time = time.time()
        logger.debug(f"开始转录音频片段,长度: {len(audio_data)} 采样点 ({len(audio_data)/WHISPER_SAMPLE_RATE:.2f}秒)")

        try:
            segments_generator, info = self.model.transcribe(audio_data, beam_size=5)
            logger.debug(f"检测到语言: {info.language}, 语言概率: {info.language_probability:.2f}")

            # The model returns a lazy generator; materialising it is what
            # actually runs the transcription.
            segments = list(segments_generator)
            logger.debug(f"转录成功,片段数: {len(segments)}")

            if segments:
                logger.debug(f"第一个片段: {segments[0]}")

            # Shift segment-relative timestamps to positions in the original
            # audio; confidence is the complement of the no-speech probability.
            results = [
                TranscriptionResult(
                    text=seg.text,
                    start_time=start_time + seg.start,
                    end_time=start_time + seg.end,
                    confidence=1.0 - seg.no_speech_prob,
                )
                for seg in segments
            ]

            process_duration = time.time() - start_process_time
            # Timing is logged at INFO so it is visible with default settings.
            logger.info(f"转录完成,耗时: {process_duration:.2f}秒,共 {len(results)} 条结果")

            return results

        except Exception as e:
            logger.error(f"转录出错: {str(e)}")
            raise

    def save_transcription(self, results: List[TranscriptionResult],
                           audio_path: str,
                           output_dir: str = "dataset/transcripts"):
        """Write transcription results to a timestamped JSON file.

        Args:
            results: Transcription results to serialise.
            audio_path: Path of the source audio; its base name prefixes the
                output file name and the full path is stored in the JSON.
            output_dir: Directory for the JSON file (created if missing).

        Returns:
            Path of the JSON file that was written.
        """
        base_name = os.path.splitext(os.path.basename(audio_path))[0]
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        output_path = os.path.join(output_dir, f"{base_name}_{timestamp}.json")

        data = {
            "audio_file": audio_path,
            "timestamp": timestamp,
            "segments": [
                {
                    "text": r.text,
                    "start_time": r.start_time,
                    "end_time": r.end_time,
                    "confidence": r.confidence,
                    "verified": r.verified,
                    "verified_text": r.verified_text,
                    "verification_notes": r.verification_notes,
                    "segment_index": r.segment_index,
                }
                for r in results
            ],
        }

        os.makedirs(output_dir, exist_ok=True)
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(data, f, ensure_ascii=False, indent=2)

        logger.info(f"转录结果已保存到: {output_path}")
        return output_path


def main():
    """Manual smoke test: transcribe one audio file and log the result."""
    audio_path = "dataset/audio/test1.wav"  # replace with a real audio file path

    # Use DEBUG to see detailed progress information.
    logger.setLevel(logging.DEBUG)

    processor = AudioTranscriber(log_level="DEBUG")

    try:
        # Read as float32 rather than soundfile's float64 default, so the
        # waveform has the dtype the model works in.
        audio_data, sample_rate = sf.read(audio_path, dtype="float32")
        logger.info(f"读取音频文件: {audio_path}, 采样率: {sample_rate}Hz, 长度: {len(audio_data)}采样点")

        # soundfile returns (frames, channels) for multi-channel files;
        # average the channels to get a mono waveform.
        if audio_data.ndim > 1:
            logger.debug(f"检测到 {audio_data.shape[1]} 个声道,已混合为单声道")
            audio_data = audio_data.mean(axis=1)

        # No resampling is done here, so surface a mismatch instead of
        # silently producing wrong timestamps.
        if sample_rate != WHISPER_SAMPLE_RATE:
            logger.warning(f"采样率为 {sample_rate}Hz,而模型期望 {WHISPER_SAMPLE_RATE}Hz,转录结果和时间戳可能不准确")

        results = processor.transcribe_segment(audio_data, start_time=0.0)

        logger.info(f"转录结果共 {len(results)} 条:")
        for res in results:
            logger.info(f"[{res.start_time:.2f} - {res.end_time:.2f}] {res.text}")

    except Exception as e:
        logger.error(f"转录测试出错: {e}")


if __name__ == "__main__":
    main()