|
from faster_whisper import WhisperModel |
|
from audio_processor import AudioSegment |
|
import json |
|
from typing import List, Dict, Optional, Tuple |
|
from dataclasses import dataclass |
|
import os |
|
from datetime import datetime |
|
import numpy as np |
|
|
|
@dataclass

class TranscriptionResult:
    """One transcribed utterance with absolute timing and optional human review."""

    text: str  # raw text emitted by the ASR model

    start_time: float  # absolute start within the source audio, in seconds

    end_time: float  # absolute end within the source audio, in seconds

    confidence: float  # derived as 1 - no_speech_prob in transcribe_segment

    verified: bool = False  # True once a human has reviewed this segment

    verified_text: Optional[str] = None  # human-corrected text, if any

    verification_notes: Optional[str] = None  # free-form reviewer notes

    segment_index: Optional[int] = None  # position in the source sequence, if tracked
|
|
|
class AudioTranscriber:
    """Transcribe AudioSegment chunks with faster-whisper and manage results.

    Provides per-segment and aggregated transcription, JSON persistence and
    reload, and a simple human-verification workflow.
    """

    # Sample rate used only to report durations in log output.
    # NOTE(review): assumed to match AudioProcessor's output rate — confirm.
    SAMPLE_RATE = 16000

    def __init__(self, model: str = "large", device: str = "cuda", compute_type: str = "int8"):
        """Initialize the transcriber.

        Args:
            model: Whisper model size or local path passed to WhisperModel.
                (Previously documented under the wrong name ``model_path``.)
            device: Device to run inference on ("cpu" or "cuda").
            compute_type: Compute/quantization type (e.g. "int8", "float16").
        """
        print("📥 Loading Whisper model...")
        self.model = WhisperModel(model, device=device, compute_type=compute_type)
        print("📥 Loading Whisper model successfully!!")

    def transcribe_segment(self, segment: AudioSegment) -> List[TranscriptionResult]:
        """Transcribe a single audio segment.

        Timestamps returned by the model are relative to the segment's audio,
        so they are shifted by ``segment.start_time`` to make them absolute.

        Args:
            segment: Speech chunk carrying ``audio_data`` samples and a
                ``start_time`` (seconds) within the source recording.

        Returns:
            One TranscriptionResult per model segment, with absolute times.
        """
        print("Model transcribe...")
        print(f"开始转录音频片段,长度: {len(segment.audio_data)} 采样点 ({len(segment.audio_data)/self.SAMPLE_RATE:.2f}秒)")
        # beam_size=5 trades speed for accuracy; language pinned to Chinese.
        segments_generator, info = self.model.transcribe(segment.audio_data,
                                                         beam_size=5,
                                                         language="zh")
        print(f"检测到语言: {info.language}, 语言概率: {info.language_probability:.2f}")
        # Materialize the lazy generator so it can be counted and indexed.
        segments = list(segments_generator)
        print(f"Model transcribe successfully! Segments count: {len(segments)}")
        if segments:
            print(segments[0])

        results = []
        for seg in segments:
            # Shift model-relative timestamps to absolute positions.
            absolute_start = segment.start_time + seg.start
            absolute_end = segment.start_time + seg.end
            results.append(TranscriptionResult(
                text=seg.text,
                start_time=absolute_start,
                end_time=absolute_end,
                # Heuristic confidence: invert the model's no-speech probability.
                confidence=1.0 - seg.no_speech_prob,
            ))

        return results

    def transcribe_aggregated_segments(self, segments: List[AudioSegment]) -> List[TranscriptionResult]:
        """Merge consecutive audio segments and transcribe them as one chunk.

        Aggregating adjacent speech gives the model more context, improving
        semantic accuracy versus transcribing tiny fragments individually.

        Args:
            segments: Consecutive audio segments, in playback order.

        Returns:
            Transcription results for the merged audio; [] for empty input.
        """
        if not segments:
            return []

        print(f"开始聚合转录 {len(segments)} 个连续片段...")

        # The merged chunk spans from the first segment's start to the last's end.
        start_time = segments[0].start_time
        end_time = segments[-1].end_time

        # np.concatenate replaces the original manual preallocate-and-copy loop;
        # the float32 cast preserves the dtype the old np.zeros buffer enforced.
        aggregated_audio = np.concatenate(
            [segment.audio_data for segment in segments]
        ).astype(np.float32, copy=False)

        print(f"聚合后音频长度: {len(aggregated_audio)} 采样点 ({len(aggregated_audio)/self.SAMPLE_RATE:.2f}秒)")

        aggregated_segment = AudioSegment(
            start_time=start_time,
            end_time=end_time,
            audio_data=aggregated_audio,
            is_speech=True
        )

        print("开始转录聚合后的音频...")
        return self.transcribe_segment(aggregated_segment)

    def save_transcription(self,
                           results: List[TranscriptionResult],
                           audio_path: str,
                           output_dir: str = "dataset/transcripts"):
        """Save transcription results to a timestamped JSON file.

        Args:
            results: Results to persist.
            audio_path: Source audio file; its basename names the output file.
            output_dir: Directory for the JSON file (created if missing).

        Returns:
            Path of the written JSON file.
        """
        base_name = os.path.splitext(os.path.basename(audio_path))[0]
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        output_path = os.path.join(output_dir, f"{base_name}_{timestamp}.json")

        data = {
            "audio_file": audio_path,
            "timestamp": timestamp,
            "segments": [
                {
                    "text": r.text,
                    "start_time": r.start_time,
                    "end_time": r.end_time,
                    "confidence": r.confidence,
                    "verified": r.verified,
                    "verified_text": r.verified_text,
                    "verification_notes": r.verification_notes,
                    "segment_index": r.segment_index,
                }
                for r in results
            ],
        }

        os.makedirs(output_dir, exist_ok=True)
        # ensure_ascii=False keeps Chinese text human-readable in the file.
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(data, f, ensure_ascii=False, indent=2)

        return output_path

    def verify_transcription(self,
                             result: TranscriptionResult,
                             verified_text: Optional[str] = None,
                             verification_notes: Optional[str] = None) -> TranscriptionResult:
        """Mark a result as human-verified, optionally recording corrections.

        Mutates ``result`` in place and returns it for convenience.
        """
        result.verified = True
        if verified_text is not None:
            result.verified_text = verified_text
        if verification_notes is not None:
            result.verification_notes = verification_notes
        return result

    def load_transcription(self, json_path: str) -> List[TranscriptionResult]:
        """Load transcription results written by :meth:`save_transcription`.

        Args:
            json_path: Path to the JSON file.

        Returns:
            Reconstructed TranscriptionResult objects.
        """
        with open(json_path, 'r', encoding='utf-8') as f:
            data = json.load(f)

        # .get() keeps loading robust for older files missing newer fields
        # (the original raised KeyError if "verified" was absent).
        return [
            TranscriptionResult(
                text=seg["text"],
                start_time=seg["start_time"],
                end_time=seg["end_time"],
                confidence=seg["confidence"],
                verified=seg.get("verified", False),
                verified_text=seg.get("verified_text"),
                verification_notes=seg.get("verification_notes"),
                segment_index=seg.get("segment_index"),
            )
            for seg in data["segments"]
        ]
|
|
|
if __name__ == "__main__":

    from audio_processor import AudioProcessor

    # Demo entry point: segment an audio file, transcribe each speech chunk,
    # and persist the combined results as JSON.
    audio_file = "../dataset/audio/test.wav"
    audio_proc = AudioProcessor()
    whisper = AudioTranscriber()

    try:
        speech_segments = audio_proc.process_audio_file(audio_file)
        total = len(speech_segments)
        print(f"检测到 {total} 个语音片段")

        transcripts = []
        for index, speech in enumerate(speech_segments, start=1):
            print(f"转录片段 {index}/{total}...")
            transcripts.extend(whisper.transcribe_segment(speech))

        saved_to = whisper.save_transcription(transcripts, audio_file)
        print(f"✅ 转录结果已保存到: {saved_to}")

    except Exception as err:
        # Broad catch is deliberate in this demo: report the error and exit.
        print(f"处理音频时出错: {err}")
|