import json
import os
from dataclasses import dataclass
from datetime import datetime
from typing import List

import numpy as np
import soundfile as sf
import webrtcvad
from scipy import signal
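
# AudioProcessor reads an audio file, resamples it to the target sample rate,
# converts it to mono, runs WebRTC VAD frame by frame, and splits the
# waveform into speech segments separated by sufficiently long silences.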


@dataclass
class AudioSegment:
    """A contiguous chunk of audio with its time span and a speech flag."""
    start_time: float
    end_time: float
    audio_data: np.ndarray
    is_speech: bool


class AudioProcessor:
    def __init__(self,
                 sample_rate: int = 16000,
                 frame_duration_ms: int = 20,
                 vad_level: int = 0,
                 min_silence_duration: float = 0.3,
                 min_speech_duration: float = 0.3,
                 amplitude_threshold: float = 0.0015):
        """
        Initialize the audio processor.

        Args:
            sample_rate: sampling rate in Hz
            frame_duration_ms: VAD frame length in milliseconds
            vad_level: VAD aggressiveness (0-3)
            min_silence_duration: minimum silence duration in seconds
            min_speech_duration: minimum speech segment length in seconds
            amplitude_threshold: amplitude threshold for the pre-VAD gate
        """
        self.sample_rate = sample_rate
        self.frame_duration_ms = frame_duration_ms
        self.frame_size = int(sample_rate * frame_duration_ms / 1000)
        self.vad = webrtcvad.Vad(vad_level)
        self.min_silence_frames = int(min_silence_duration * 1000 / frame_duration_ms)
        self.min_speech_frames = int(min_speech_duration * 1000 / frame_duration_ms)
        self.amplitude_threshold = amplitude_threshold
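        # Note: webrtcvad only accepts 10, 20, or 30 ms frames at 8, 16, 32,
        # or 48 kHz; the defaults above (16 kHz, 20 ms) stay within spec.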

    def _is_speech_frame(self, frame: np.ndarray) -> bool:
        """
        Decide whether a single frame contains speech.
        """
        if len(frame) != self.frame_size:
            return False

        # webrtcvad expects 16-bit PCM, so scale the float samples
        # (assumed to lie in [-1.0, 1.0]) to int16.
        frame_int16 = np.clip(frame * 32768, -32768, 32767).astype(np.int16)

        # Cheap amplitude gate: treat very quiet frames as silence
        # without consulting the VAD.
        frame_amplitude = np.max(np.abs(frame))
        if frame_amplitude < self.amplitude_threshold:
            return False

        try:
            return self.vad.is_speech(frame_int16.tobytes(), self.sample_rate)
        except Exception as e:
            print(f"VAD error: {e}")

        # Fallback if the VAD call fails: require a clearly louder frame.
        return frame_amplitude >= self.amplitude_threshold * 2

    def process_audio_file(self, audio_path: str) -> List[AudioSegment]:
        """
        Process an audio file and return the list of detected segments.
        """
        print(f"Reading audio file: {audio_path}")
        audio_data, sample_rate = sf.read(audio_path)
        print(f"Sample rate: {sample_rate}Hz, shape: {audio_data.shape}")

        if sample_rate != self.sample_rate:
            print(f"Resampling audio from {sample_rate}Hz to {self.sample_rate}Hz")
            num_samples = int(len(audio_data) * self.sample_rate / sample_rate)
            audio_data = signal.resample(audio_data, num_samples)
            print(f"Length after resampling: {len(audio_data)} samples")

        if len(audio_data.shape) > 1:
            print("Multi-channel audio detected, converting to mono")
            audio_data = audio_data.mean(axis=1)

        segments: List[AudioSegment] = []
        print(f"Processing audio, total length: {len(audio_data)} samples ({len(audio_data)/self.sample_rate:.2f}s)")

        # Segmentation state machine.
        current_segment_start = 0
        silence_frame_count = 0
        is_in_speech = False

        total_frames = len(audio_data) // self.frame_size
        speech_frames = 0
        for i in range(0, len(audio_data), self.frame_size):
            frame = audio_data[i:i + self.frame_size]
            if len(frame) < self.frame_size:
                # Zero-pad the final partial frame to a full VAD frame.
                frame = np.pad(frame, (0, self.frame_size - len(frame)), 'constant')

            is_speech = self._is_speech_frame(frame)
            if is_speech:
                speech_frames += 1

            if is_speech and not is_in_speech:
                # Transition: silence -> speech.
                current_segment_start = i
                is_in_speech = True
                silence_frame_count = 0
                print(f"\nSpeech onset detected at {i/self.sample_rate:.2f}s")
            elif not is_speech and is_in_speech:
                silence_frame_count += 1

                if silence_frame_count >= self.min_silence_frames:
                    # Speech ended where the run of silent frames began; the
                    # current frame is the last of silence_frame_count silent
                    # frames, hence the (count - 1) offset.
                    segment_end = i - (silence_frame_count - 1) * self.frame_size
                    duration_frames = (segment_end - current_segment_start) // self.frame_size

                    if duration_frames >= self.min_speech_frames:
                        start_time = current_segment_start / self.sample_rate
                        end_time = segment_end / self.sample_rate
                        print(f"Saving speech segment: {start_time:.2f}s -> {end_time:.2f}s "
                              f"(duration: {end_time-start_time:.2f}s)")
                        segments.append(AudioSegment(
                            start_time=start_time,
                            end_time=end_time,
                            audio_data=audio_data[current_segment_start:segment_end],
                            is_speech=True
                        ))
                    else:
                        print(f"Discarding too-short speech segment: "
                              f"{duration_frames * self.frame_duration_ms / 1000:.2f}s")

                    is_in_speech = False

        # Flush a segment that was still open when the audio ended.
        if is_in_speech:
            segment_end = len(audio_data)
            duration_frames = (segment_end - current_segment_start) // self.frame_size
            if duration_frames >= self.min_speech_frames:
                start_time = current_segment_start / self.sample_rate
                end_time = segment_end / self.sample_rate
                print(f"Saving final speech segment: {start_time:.2f}s -> {end_time:.2f}s "
                      f"(duration: {end_time-start_time:.2f}s)")
                segments.append(AudioSegment(
                    start_time=start_time,
                    end_time=end_time,
                    audio_data=audio_data[current_segment_start:segment_end],
                    is_speech=True
                ))
            else:
                print(f"Discarding too-short final speech segment: "
                      f"{duration_frames * self.frame_duration_ms / 1000:.2f}s")

        print("\nAudio processing complete:")
        print(f"Total frames: {total_frames}")
        print(f"Speech frames: {speech_frames}")
        print(f"Detected speech segments: {len(segments)}")

        # Persist an intermediate summary of the segmentation as JSON.
        temp_dir = "../dataset/audio/temp"
        os.makedirs(temp_dir, exist_ok=True)

        temp_data = {
            "audio_file": audio_path,
            "timestamp": datetime.now().strftime("%Y%m%d_%H%M%S"),
            "total_frames": total_frames,
            "speech_frames": speech_frames,
            "segments": [
                {
                    "start_time": seg.start_time,
                    "end_time": seg.end_time,
                    "duration": seg.end_time - seg.start_time,
                    "is_speech": seg.is_speech
                }
                for seg in segments
            ]
        }

        base_name = os.path.splitext(os.path.basename(audio_path))[0]
        temp_path = os.path.join(temp_dir, f"{base_name}_segments_{temp_data['timestamp']}.json")
        with open(temp_path, 'w', encoding='utf-8') as f:
            json.dump(temp_data, f, ensure_ascii=False, indent=2)
        print(f"\nIntermediate results saved to: {temp_path}")

        return segments

    def save_segment(self, segment: AudioSegment, output_path: str):
        """
        Write an audio segment to a file.
        """
        sf.write(output_path, segment.audio_data, self.sample_rate)


if __name__ == "__main__":
    processor = AudioProcessor()

    audio_path = "dataset/audio/test.wav"
    try:
        segments = processor.process_audio_file(audio_path)
        print(f"Detected {len(segments)} speech segments:")
        for i, segment in enumerate(segments):
            print(f"Segment {i+1}: {segment.start_time:.2f}s -> {segment.end_time:.2f}s")
    except Exception as e:
        print(f"Error while processing audio: {e}")
|