from faster_whisper import WhisperModel
from audio_processor import AudioSegment
import json
from typing import List, Dict, Optional, Tuple
from dataclasses import dataclass
import os
from datetime import datetime
import numpy as np
@dataclass
class TranscriptionResult:
text: str
start_time: float
end_time: float
confidence: float
verified: bool = False
verified_text: Optional[str] = None
verification_notes: Optional[str] = None
    segment_index: Optional[int] = None  # index of this segment within the source audio
class AudioTranscriber:
    def __init__(self, model: str = "large", device: str = "cuda", compute_type: str = "int8"):
        """
        Initialize the transcriber.

        Args:
            model: Whisper model size (e.g. "large") or path to a local model
            device: device to run on ("cpu" or "cuda")
            compute_type: computation type (e.g. "int8", "float16")
        """
        print("📥 Loading Whisper model...")
        self.model = WhisperModel(model, device=device, compute_type=compute_type)
        print("✅ Whisper model loaded successfully!")
    def transcribe_segment(self, segment: AudioSegment) -> List[TranscriptionResult]:
        """
        Transcribe a single audio segment.
        """
        print("Running model transcription...")
        print(f"Transcribing segment: {len(segment.audio_data)} samples ({len(segment.audio_data)/16000:.2f}s at 16 kHz)")
        segments_generator, info = self.model.transcribe(
            segment.audio_data,
            beam_size=5,
            language="zh",
        )
print(f"检测到语言: {info.language}, 语言概率: {info.language_probability:.2f}")
segments = list(segments_generator)
print(f"Model transcribe successfully! Segments count: {len(segments)}")
if len(segments) > 0:
print(segments[0])
results = []
for seg in segments:
            # Shift timestamps so they are absolute positions in the original audio
absolute_start = segment.start_time + seg.start
absolute_end = segment.start_time + seg.end
result = TranscriptionResult(
text=seg.text,
start_time=absolute_start,
end_time=absolute_end,
confidence=1.0 - seg.no_speech_prob
)
results.append(result)
return results
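    # Worked example (illustrative numbers): if segment.start_time is 12.0s and
    # Whisper reports a sub-segment at 0.5-2.3s within it, the stored result
    # covers 12.5-14.3s of the original recording.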
    def transcribe_aggregated_segments(self, segments: List[AudioSegment]) -> List[TranscriptionResult]:
        """
        Aggregate several consecutive audio segments and transcribe them as one,
        giving the model more context and improving semantic accuracy.

        Args:
            segments: list of consecutive audio segments to aggregate

        Returns:
            List of transcription results
        """
if not segments:
return []
print(f"开始聚合转录 {len(segments)} 个连续片段...")
# 记录第一个片段的开始时间和最后一个片段的结束时间
start_time = segments[0].start_time
end_time = segments[-1].end_time
        # Compute the total length and preallocate the merged audio buffer
total_length = sum(len(segment.audio_data) for segment in segments)
aggregated_audio = np.zeros(total_length, dtype=np.float32)
        # Copy each segment's samples into the merged buffer
current_position = 0
for segment in segments:
segment_length = len(segment.audio_data)
aggregated_audio[current_position:current_position + segment_length] = segment.audio_data
current_position += segment_length
print(f"聚合后音频长度: {len(aggregated_audio)} 采样点 ({len(aggregated_audio)/16000:.2f}秒)")
# 创建一个临时的聚合片段对象
aggregated_segment = AudioSegment(
start_time=start_time,
end_time=end_time,
audio_data=aggregated_audio,
is_speech=True
)
        # Transcribe the aggregated audio
        print("Transcribing aggregated audio...")
return self.transcribe_segment(aggregated_segment)
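    # A more concise, functionally equivalent merge (sketch, not the original code):
    #   aggregated_audio = np.concatenate([s.audio_data for s in segments]).astype(np.float32)
    # The preallocated-buffer loop above performs the same copy explicitly.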
def save_transcription(self,
results: List[TranscriptionResult],
audio_path: str,
output_dir: str = "dataset/transcripts"):
"""
        Save transcription results to a JSON file.
"""
        # Build the output file name from the audio name and a timestamp
base_name = os.path.splitext(os.path.basename(audio_path))[0]
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
output_path = os.path.join(output_dir, f"{base_name}_{timestamp}.json")
        # Assemble the payload to write
data = {
"audio_file": audio_path,
"timestamp": timestamp,
"segments": [
{
"text": r.text,
"start_time": r.start_time,
"end_time": r.end_time,
"confidence": r.confidence,
"verified": r.verified,
"verified_text": r.verified_text,
"verification_notes": r.verification_notes,
"segment_index": r.segment_index # 添加片段索引到输出
}
for r in results
]
}
        # Write to disk, creating the output directory if needed
os.makedirs(output_dir, exist_ok=True)
with open(output_path, 'w', encoding='utf-8') as f:
json.dump(data, f, ensure_ascii=False, indent=2)
return output_path
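    # Example (hypothetical names): saving results for "dataset/audio/test.wav" on
    # 2024-01-01 12:00:00 would produce dataset/transcripts/test_20240101_120000.json.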
def verify_transcription(self,
result: TranscriptionResult,
verified_text: Optional[str] = None,
verification_notes: Optional[str] = None) -> TranscriptionResult:
"""
        Mark a result as manually verified, optionally attaching corrected text and reviewer notes.
"""
result.verified = True
if verified_text is not None:
result.verified_text = verified_text
if verification_notes is not None:
result.verification_notes = verification_notes
return result
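    # Usage sketch (hypothetical values):
    #   result = transcriber.verify_transcription(result,
    #                                             verified_text="corrected transcript",
    #                                             verification_notes="fixed a homophone")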
def load_transcription(self, json_path: str) -> List[TranscriptionResult]:
"""
        Load transcription results from a JSON file produced by save_transcription.
"""
with open(json_path, 'r', encoding='utf-8') as f:
data = json.load(f)
results = []
for seg in data["segments"]:
result = TranscriptionResult(
text=seg["text"],
start_time=seg["start_time"],
end_time=seg["end_time"],
confidence=seg["confidence"],
verified=seg["verified"],
verified_text=seg.get("verified_text"),
verification_notes=seg.get("verification_notes"),
segment_index=seg.get("segment_index") # 加载片段索引
)
results.append(result)
return results
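    # Round-trip sketch (path is illustrative): results saved with save_transcription
    # can be reloaded, verified, and re-saved:
    #   results = transcriber.load_transcription("dataset/transcripts/test_20240101_120000.json")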
if __name__ == "__main__":
    # Smoke test: split an audio file and transcribe each segment
from audio_processor import AudioProcessor
    # Initialize the processor and the transcriber
processor = AudioProcessor()
transcriber = AudioTranscriber()
    # Example: process and transcribe an audio file
    audio_path = "../dataset/audio/test.wav"  # replace with an actual audio file path
try:
        # 1. Split the audio into speech segments
        segments = processor.process_audio_file(audio_path)
        print(f"Detected {len(segments)} speech segments")
        # 2. Transcribe each segment
all_results = []
for i, segment in enumerate(segments):
print(f"转录片段 {i+1}/{len(segments)}...")
results = transcriber.transcribe_segment(segment)
all_results.extend(results)
        # 3. Save the results
output_path = transcriber.save_transcription(all_results, audio_path)
print(f"✅ 转录结果已保存到: {output_path}")
except Exception as e:
print(f"处理音频时出错: {e}")