|
from faster_whisper import WhisperModel |
|
import json |
|
from typing import List, Optional, Union |
|
from dataclasses import dataclass |
|
import os |
|
from datetime import datetime |
|
import numpy as np |
|
import soundfile as sf |
|
import time |
|
import logging |
|
|
|
@dataclass
class TranscriptionResult:
    """One transcribed span of audio plus optional manual-verification data.

    The first four fields are required and positional; the remaining fields
    default to "not verified / not set".
    """

    # Text recognised for this span.
    text: str

    # Boundaries of the span, in seconds.
    start_time: float
    end_time: float

    # Score attached by whoever builds the result (this file stores
    # ``1.0 - no_speech_prob`` of the Whisper segment).
    confidence: float

    # Verification bookkeeping; nothing in this file fills these in.
    verified: bool = False
    verified_text: Optional[str] = None
    verification_notes: Optional[str] = None

    # Optional position of the segment; stays None unless a caller sets it.
    segment_index: Optional[int] = None
|
|
|
|
|
def setup_logger(name, level=logging.INFO):
    """Return the logger called *name*, reset to a single console handler.

    Any handlers already attached to that logger are detached first, so
    calling this function repeatedly never produces duplicated output.

    Args:
        name: Name handed to ``logging.getLogger``.
        level: Level applied to the logger (defaults to ``logging.INFO``).

    Returns:
        The configured ``logging.Logger`` with propagation switched off.
    """
    log = logging.getLogger(name)

    # Iterate over a snapshot because removeHandler mutates log.handlers.
    for stale in list(log.handlers):
        log.removeHandler(stale)

    console = logging.StreamHandler()
    console.setFormatter(
        logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    )

    log.setLevel(level)
    log.addHandler(console)

    # Keep records away from ancestor loggers so they are not printed twice.
    log.propagate = False
    return log
|
|
|
|
|
# Module-wide logger shared by everything in this file.
logger = setup_logger(name="transcribe")
|
|
|
class AudioTranscriber:
    """Transcribe in-memory audio with a faster-whisper model and save results as JSON."""

    # Sample rate (Hz) used to turn a sample count into seconds for logging.
    SAMPLE_RATE = 16000

    def __init__(self, model: str = "medium", device: str = "cuda", compute_type: str = "int8",
                 log_level: Union[int, str] = logging.INFO):
        """
        Initialise the transcriber and load the Whisper model.

        Args:
            model: Whisper model name.
            device: Device to run on ("cpu" or "cuda").
            compute_type: Compute type handed to ``WhisperModel``.
            log_level: Log level, either a ``logging`` constant (DEBUG, INFO, ...)
                or the matching name as a string (upper-cased before lookup).

        Raises:
            AttributeError: If ``log_level`` is a string that is not an
                attribute of the ``logging`` module.
        """
        if isinstance(log_level, str):
            log_level = getattr(logging, log_level.upper())
        # NOTE: this sets the level of the shared module-level logger, so it
        # affects every user of that logger, not only this instance.
        logger.setLevel(log_level)

        logger.debug(f"初始化转录器: model={model}, device={device}, compute_type={compute_type}")

        self.model = WhisperModel(model, device=device, compute_type=compute_type)

        logger.debug("Whisper模型加载完成")

    @staticmethod
    def _prepare_audio(audio_data: np.ndarray) -> np.ndarray:
        """Return *audio_data* as a float32 array, down-mixed to mono when 2-D.

        No resampling and no amplitude rescaling is performed.
        """
        samples = np.asarray(audio_data)
        if samples.ndim == 2:
            # Assumes the (frames, channels) layout that soundfile.read
            # produces for multi-channel files -- TODO confirm for other callers.
            samples = samples.mean(axis=1)
        # copy=False avoids a copy when the data is already float32.
        return samples.astype(np.float32, copy=False)

    def transcribe_segment(self, audio_data: np.ndarray, start_time: float = 0.0) -> List[TranscriptionResult]:
        """
        Transcribe a single audio segment.

        Args:
            audio_data: Audio samples as a numpy array. The array is converted
                to float32 and, if 2-D, averaged over axis 1 before it is given
                to the model. Presumably expected at 16 kHz (the duration log
                assumes ``SAMPLE_RATE``) -- the rate is not checked here.
            start_time: Offset of this segment inside the whole recording, in
                seconds; it is added to every segment's start and end.

        Returns:
            One ``TranscriptionResult`` per segment returned by the model, with
            times shifted by ``start_time``. An empty list for empty input.

        Raises:
            Exception: Any error raised while transcribing is logged and
                re-raised unchanged.
        """
        start_process_time = time.time()

        audio_data = self._prepare_audio(audio_data)
        sample_count = len(audio_data)

        logger.debug(f"开始转录音频片段,长度: {sample_count} 采样点 ({sample_count/self.SAMPLE_RATE:.2f}秒)")

        if sample_count == 0:
            # Nothing to decode; skip the model call entirely.
            logger.warning("音频片段为空,跳过转录")
            return []

        try:
            segments_generator, info = self.model.transcribe(audio_data, beam_size=5)

            logger.debug(f"检测到语言: {info.language}, 语言概率: {info.language_probability:.2f}")

            # Materialise the generator so the segment count is known and any
            # error raised while iterating is caught by this try block.
            segments = list(segments_generator)

            logger.debug(f"转录成功,片段数: {len(segments)}")
            if segments:
                logger.debug(f"第一个片段: {segments[0]}")

            # Segment times are relative to this clip; shift them by start_time.
            results = [
                TranscriptionResult(
                    text=seg.text,
                    start_time=start_time + seg.start,
                    end_time=start_time + seg.end,
                    confidence=1.0 - seg.no_speech_prob,
                )
                for seg in segments
            ]

            process_duration = time.time() - start_process_time

            logger.info(f"转录完成,耗时: {process_duration:.2f}秒,共 {len(results)} 条结果")

            return results

        except Exception as e:
            logger.error(f"转录出错: {str(e)}")
            raise

    def save_transcription(self,
                           results: List[TranscriptionResult],
                           audio_path: str,
                           output_dir: str = "dataset/transcripts"):
        """
        Save transcription results to a JSON file.

        The file is named ``<audio base name>_<YYYYmmdd_HHMMSS>.json`` and is
        written inside ``output_dir``, which is created if it does not exist.

        Args:
            results: Results to serialise, one JSON object per entry.
            audio_path: Path of the source audio; stored in the JSON and used
                to derive the output file name.
            output_dir: Directory that receives the JSON file.

        Returns:
            Path of the file that was written.
        """
        base_name = os.path.splitext(os.path.basename(audio_path))[0]
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        output_path = os.path.join(output_dir, f"{base_name}_{timestamp}.json")

        data = {
            "audio_file": audio_path,
            "timestamp": timestamp,
            "segments": [
                {
                    "text": r.text,
                    "start_time": r.start_time,
                    "end_time": r.end_time,
                    "confidence": r.confidence,
                    "verified": r.verified,
                    "verified_text": r.verified_text,
                    "verification_notes": r.verification_notes,
                    "segment_index": r.segment_index,
                }
                for r in results
            ],
        }

        os.makedirs(output_dir, exist_ok=True)
        # ensure_ascii=False keeps non-ASCII transcript text readable in the file.
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(data, f, ensure_ascii=False, indent=2)

        logger.info(f"转录结果已保存到: {output_path}")
        return output_path
|
|
|
|
|
if __name__ == "__main__":
    # Manual smoke test: transcribe one local WAV file and log each segment.
    audio_path = "dataset/audio/test1.wav"

    # transcribe_segment converts sample counts to seconds assuming this rate.
    expected_rate = 16000

    logger.setLevel(logging.DEBUG)

    processor = AudioTranscriber(log_level="DEBUG")

    try:
        # Ask soundfile for float32 directly instead of its float64 default.
        audio_data, sample_rate = sf.read(audio_path, dtype="float32")
        if audio_data.ndim > 1:
            # Multi-channel files come back as (frames, channels); average to mono.
            audio_data = audio_data.mean(axis=1)
        logger.info(f"读取音频文件: {audio_path}, 采样率: {sample_rate}Hz, 长度: {len(audio_data)}采样点")

        if sample_rate != expected_rate:
            # No resampling is done here, so only warn about the mismatch.
            logger.warning(
                f"采样率为 {sample_rate}Hz,与预期的 {expected_rate}Hz 不一致,时间戳和识别结果可能不准确"
            )

        results = processor.transcribe_segment(audio_data, start_time=0.0)

        logger.info(f"转录结果共 {len(results)} 条:")
        for res in results:
            logger.info(f"[{res.start_time:.2f} - {res.end_time:.2f}] {res.text}")
    except Exception as e:
        # logger.exception keeps the traceback; exit non-zero so callers of the
        # script can detect the failure.
        logger.exception(f"转录测试出错: {e}")
        raise SystemExit(1)
|
|