from faster_whisper import WhisperModel
from audio_processor import AudioSegment
import json
from typing import List, Dict, Optional, Tuple
from dataclasses import dataclass
import os
from datetime import datetime
import numpy as np


@dataclass
class TranscriptionResult:
    text: str
    start_time: float
    end_time: float
    confidence: float
    verified: bool = False
    verified_text: Optional[str] = None
    verification_notes: Optional[str] = None
    segment_index: Optional[int] = None  # index of the source segment


class AudioTranscriber:
    def __init__(self, model: str = "large", device: str = "cuda", compute_type: str = "int8"):
        """
        Initialize the transcriber.

        Args:
            model: Whisper model size or path
            device: device to run on ("cpu" or "cuda")
            compute_type: computation type (e.g. "int8", "float16")
        """
        print("📥 Loading Whisper model...")
        self.model = WhisperModel(model, device=device, compute_type=compute_type)
        print("📥 Whisper model loaded successfully!")

    def transcribe_segment(self, segment: AudioSegment) -> List[TranscriptionResult]:
        """
        Transcribe a single audio segment.
        """
        print("Model transcribe...")
        print(f"Transcribing audio segment, length: {len(segment.audio_data)} samples "
              f"({len(segment.audio_data)/16000:.2f}s)")
        segments_generator, info = self.model.transcribe(segment.audio_data, beam_size=5, language="zh")
        print(f"Detected language: {info.language}, language probability: {info.language_probability:.2f}")
        segments = list(segments_generator)
        print(f"Model transcribe successful! Segment count: {len(segments)}")
        if len(segments) > 0:
            print(segments[0])

        results = []
        for seg in segments:
            # Shift timestamps so they refer to positions in the original audio
            absolute_start = segment.start_time + seg.start
            absolute_end = segment.start_time + seg.end
            result = TranscriptionResult(
                text=seg.text,
                start_time=absolute_start,
                end_time=absolute_end,
                confidence=1.0 - seg.no_speech_prob
            )
            results.append(result)
        return results

    def transcribe_aggregated_segments(self, segments: List[AudioSegment]) -> List[TranscriptionResult]:
        """
        Concatenate several consecutive audio segments and transcribe them as one,
        giving the model more context and improving semantic accuracy.

        Args:
            segments: list of consecutive audio segments to aggregate

        Returns:
            List of transcription results
        """
        if not segments:
            return []

        print(f"Aggregating {len(segments)} consecutive segments for transcription...")

        # Keep the start time of the first segment and the end time of the last
        start_time = segments[0].start_time
        end_time = segments[-1].end_time

        # Compute the total length and allocate the merged audio buffer
        total_length = sum(len(segment.audio_data) for segment in segments)
        aggregated_audio = np.zeros(total_length, dtype=np.float32)

        # Copy each segment's samples into the merged buffer
        current_position = 0
        for segment in segments:
            segment_length = len(segment.audio_data)
            aggregated_audio[current_position:current_position + segment_length] = segment.audio_data
            current_position += segment_length

        print(f"Aggregated audio length: {len(aggregated_audio)} samples "
              f"({len(aggregated_audio)/16000:.2f}s)")

        # Wrap the merged audio in a temporary AudioSegment
        aggregated_segment = AudioSegment(
            start_time=start_time,
            end_time=end_time,
            audio_data=aggregated_audio,
            is_speech=True
        )

        # Transcribe the aggregated audio
        print("Transcribing aggregated audio...")
        return self.transcribe_segment(aggregated_segment)

    def save_transcription(self, results: List[TranscriptionResult], audio_path: str,
                           output_dir: str = "dataset/transcripts"):
        """
        Save transcription results to a JSON file.
        """
        # Build the output file name
        base_name = os.path.splitext(os.path.basename(audio_path))[0]
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        output_path = os.path.join(output_dir, f"{base_name}_{timestamp}.json")

        # Prepare the data to save
        data = {
            "audio_file": audio_path,
            "timestamp": timestamp,
            "segments": [
                {
                    "text": r.text,
                    "start_time": r.start_time,
                    "end_time": r.end_time,
                    "confidence": r.confidence,
                    "verified": r.verified,
                    "verified_text": r.verified_text,
                    "verification_notes": r.verification_notes,
                    "segment_index": r.segment_index  # include the segment index in the output
                }
                for r in results
            ]
        }

        # Write to file
        os.makedirs(output_dir, exist_ok=True)
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(data, f, ensure_ascii=False, indent=2)
        return output_path

    def verify_transcription(self, result: TranscriptionResult,
                             verified_text: Optional[str] = None,
                             verification_notes: Optional[str] = None) -> TranscriptionResult:
        """
        Mark a transcription result as manually verified.
        """
        result.verified = True
        if verified_text is not None:
            result.verified_text = verified_text
        if verification_notes is not None:
            result.verification_notes = verification_notes
        return result

    def load_transcription(self, json_path: str) -> List[TranscriptionResult]:
        """
        Load transcription results from a JSON file.
        """
        with open(json_path, 'r', encoding='utf-8') as f:
            data = json.load(f)

        results = []
        for seg in data["segments"]:
            result = TranscriptionResult(
                text=seg["text"],
                start_time=seg["start_time"],
                end_time=seg["end_time"],
                confidence=seg["confidence"],
                verified=seg["verified"],
                verified_text=seg.get("verified_text"),
                verification_notes=seg.get("verification_notes"),
                segment_index=seg.get("segment_index")  # restore the segment index
            )
            results.append(result)
        return results


if __name__ == "__main__":
    # Test code
    from audio_processor import AudioProcessor

    # Initialize processor and transcriber
    processor = AudioProcessor()
    transcriber = AudioTranscriber()

    # Example: process and transcribe an audio file
    audio_path = "../dataset/audio/test.wav"  # replace with an actual audio file path
    try:
        # 1. Split the audio into speech segments
        segments = processor.process_audio_file(audio_path)
        print(f"Detected {len(segments)} speech segments")

        # 2. Transcribe each segment
        all_results = []
        for i, segment in enumerate(segments):
            print(f"Transcribing segment {i+1}/{len(segments)}...")
            results = transcriber.transcribe_segment(segment)
            all_results.extend(results)

        # 3. Save results
        output_path = transcriber.save_transcription(all_results, audio_path)
        print(f"✅ Transcription saved to: {output_path}")
    except Exception as e:
        print(f"Error while processing audio: {e}")
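

# The test block above transcribes each segment independently. For short, choppy
# VAD output, transcribe_aggregated_segments() can give the model more context.
# The sketch below is a minimal, hedged example of batching consecutive segments
# before transcription; the batch size of 5 and the helper name
# _example_aggregated_workflow are illustrative assumptions, not part of the
# original pipeline. It only uses functions already defined in this module and
# in audio_processor.
def _example_aggregated_workflow(audio_path: str, batch_size: int = 5) -> str:
    """Usage sketch: transcribe consecutive segments in aggregated batches."""
    from audio_processor import AudioProcessor

    processor = AudioProcessor()
    transcriber = AudioTranscriber()

    segments = processor.process_audio_file(audio_path)
    all_results = []
    # Group consecutive segments into batches of `batch_size` (an assumed value)
    for batch_index, start in enumerate(range(0, len(segments), batch_size)):
        batch = segments[start:start + batch_size]
        results = transcriber.transcribe_aggregated_segments(batch)
        for r in results:
            r.segment_index = batch_index  # record which batch produced this result
        all_results.extend(results)

    # Persist the combined results and return the JSON path
    return transcriber.save_transcription(all_results, audio_path)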