"""Command-line tool for splitting, transcribing and verifying audio.

Three actions are supported:

* ``process``   - split an audio file into segments and transcribe each one.
* ``verify``    - interactively review a saved transcription file.
* ``aggregate`` - transcribe several already-saved segments as one unit.
"""

import argparse
import json
import os
import sys
from typing import List, Optional, Tuple

import soundfile as sf

from audio_processor import AudioProcessor, AudioSegment
from audio_transcriber import AudioTranscriber, TranscriptionResult

# Directory where `process` writes segments and where `aggregate` looks for them.
DEFAULT_SEGMENT_DIR = "dataset/audio/segments"


def process_audio(
    audio_path: str, output_dir: str = DEFAULT_SEGMENT_DIR
) -> Tuple[List[str], List[AudioSegment]]:
    """Split an audio file into segments and save every segment to disk.

    Args:
        audio_path: Path of the audio file to split.
        output_dir: Directory that receives the segment files; created if
            it does not exist.

    Returns:
        A pair ``(segment_paths, segments)``: the paths of the saved segment
        files and the segment objects returned by the processor, in the
        same order.
    """
    processor = AudioProcessor()
    segments = processor.process_audio_file(audio_path)

    os.makedirs(output_dir, exist_ok=True)

    # Files are named "<input base name>_segment_<n>.wav" with n starting at 1.
    base_name = os.path.splitext(os.path.basename(audio_path))[0]
    segment_paths = []
    for number, segment in enumerate(segments, start=1):
        segment_path = os.path.join(output_dir, f"{base_name}_segment_{number}.wav")
        processor.save_segment(segment, segment_path)
        segment_paths.append(segment_path)

    return segment_paths, segments


def _load_segment(path: str, original_segment: AudioSegment) -> AudioSegment:
    """Read the audio stored at *path* and wrap it with the original timestamps."""
    # The sample rate returned by sf.read is not needed here.
    audio_data, _ = sf.read(path)
    return AudioSegment(
        start_time=original_segment.start_time,
        end_time=original_segment.end_time,
        audio_data=audio_data,
        is_speech=True,
    )


def transcribe_segments(
    segment_paths: List[str], original_segments: List[AudioSegment]
) -> str:
    """Transcribe every saved segment and store the combined result.

    Args:
        segment_paths: Paths of the saved segment files.
        original_segments: Segments whose ``start_time``/``end_time`` are
            reused; entry ``i`` belongs to ``segment_paths[i]``.

    Returns:
        Whatever ``AudioTranscriber.save_transcription`` returns for the
        combined results (used by callers as the output file path).

    Raises:
        ValueError: If ``segment_paths`` is empty, or if there are fewer
            original segments than segment paths.
    """
    if not segment_paths:
        raise ValueError("必须提供至少一个音频片段路径")
    if len(original_segments) < len(segment_paths):
        # Fail before doing any transcription work instead of half-way through.
        raise ValueError(
            f"原始片段数量 ({len(original_segments)}) 少于片段路径数量 ({len(segment_paths)})"
        )

    transcriber = AudioTranscriber()
    all_results = []
    total = len(segment_paths)

    for i, path in enumerate(segment_paths):
        print(f"转录片段 {i+1}/{total}: {path}")
        segment = _load_segment(path, original_segments[i])
        results = transcriber.transcribe_segment(segment)
        # Tag every result with the (0-based) segment it came from.
        for result in results:
            result.segment_index = i
        all_results.extend(results)

    # The first segment's path is handed to save_transcription as the base
    # path for the output.
    return transcriber.save_transcription(all_results, segment_paths[0])


def transcribe_aggregated_segments(
    segment_paths: List[str],
    original_segments: List[AudioSegment],
    segment_indices: List[int],
) -> str:
    """Transcribe the selected segments together and store the result.

    Args:
        segment_paths: Paths of all saved segment files.
        original_segments: Segments supplying the timestamps; entry ``i``
            belongs to ``segment_paths[i]``.
        segment_indices: 0-based positions (into ``segment_paths``) of the
            segments to aggregate. The list is not modified.

    Returns:
        Whatever ``AudioTranscriber.save_transcription`` returns for the
        aggregated results (used by callers as the output file path).

    Raises:
        ValueError: If ``segment_indices`` is empty or contains an index
            outside ``0 .. len(segment_paths) - 1``.
    """
    if not segment_indices:
        raise ValueError("必须指定至少一个片段索引")

    for idx in segment_indices:
        if idx < 0 or idx >= len(segment_paths):
            raise ValueError(f"无效的片段索引: {idx},有效范围: 0-{len(segment_paths)-1}")

    # Process in ascending order; work on a copy so the caller's list is untouched.
    ordered_indices = sorted(segment_indices)
    print(f"准备聚合转录片段索引: {ordered_indices}")

    transcriber = AudioTranscriber()
    last_index = len(segment_paths) - 1
    segments_to_aggregate = []
    for idx in ordered_indices:
        path = segment_paths[idx]
        print(f"加载片段 {idx}/{last_index}: {path}")
        segments_to_aggregate.append(_load_segment(path, original_segments[idx]))

    print(f"开始聚合转录 {len(segments_to_aggregate)} 个片段...")
    results = transcriber.transcribe_aggregated_segments(segments_to_aggregate)

    # For aggregated results, segment_index holds the list of source indices.
    # Each result gets its own copy so they do not alias one shared list.
    for result in results:
        result.segment_index = list(ordered_indices)

    # Build "<first segment base name>_aggregated_<first>_to_<last>.wav" next
    # to the first segment. No audio is written to this path in this function;
    # it is only passed to save_transcription as the base path.
    first_path = segment_paths[ordered_indices[0]]
    base_name = os.path.splitext(os.path.basename(first_path))[0]
    aggregated_name = (
        f"{base_name}_aggregated_{ordered_indices[0]}_to_{ordered_indices[-1]}"
    )
    temp_path = os.path.join(os.path.dirname(first_path), f"{aggregated_name}.wav")

    output_path = transcriber.save_transcription(results, temp_path)
    print(f"✅ 聚合转录结果已保存到: {output_path}")
    return output_path


def verify_transcription(json_path: str) -> None:
    """Interactively review the unverified entries of a transcription file.

    For every entry whose ``verified`` flag is false the user may confirm it,
    replace its text, attach a note, or skip it. The results are saved again
    afterwards; if the session is interrupted (Ctrl-C or end of input) the
    results are still saved before returning.

    Args:
        json_path: Path passed to ``AudioTranscriber.load_transcription``
            and, afterwards, to ``save_transcription``.
    """
    transcriber = AudioTranscriber()
    results = transcriber.load_transcription(json_path)

    print("\n=== 转录结果验证 ===")
    print(f"加载转录文件: {json_path}")
    print(f"共有 {len(results)} 个片段需要验证")

    try:
        for i, result in enumerate(results):
            if result.verified:
                continue

            print(f"\n片段 {i+1}/{len(results)}")
            print(f"时间: {result.start_time:.2f}s -> {result.end_time:.2f}s")
            print(f"转录文本: {result.text}")
            print(f"置信度: {result.confidence:.2f}")

            # Keep asking until one of the four menu options is chosen.
            while True:
                choice = input(
                    "\n选项:\n1. 确认正确\n2. 修改文本\n3. 添加注释\n4. 跳过\n请选择 (1-4): "
                ).strip()
                if choice == "1":
                    transcriber.verify_transcription(result)
                    break
                elif choice == "2":
                    new_text = input("请输入正确的文本: ")
                    transcriber.verify_transcription(result, verified_text=new_text)
                    break
                elif choice == "3":
                    notes = input("请输入注释: ")
                    transcriber.verify_transcription(result, verification_notes=notes)
                    break
                elif choice == "4":
                    break
                else:
                    print("无效的选择,请输入 1-4 之间的数字")
    except (KeyboardInterrupt, EOFError):
        # Do not throw away the entries reviewed so far.
        print("\n⚠️ 验证已中断,正在保存已完成的验证结果")

    # NOTE(review): the other callers pass an audio (.wav) path as the second
    # argument of save_transcription; confirm that passing the JSON path here
    # writes back to the same file.
    transcriber.save_transcription(results, json_path)
    print("\n✅ 验证结果已保存")


def get_existing_segments(
    base_dir: str = DEFAULT_SEGMENT_DIR,
) -> Tuple[List[str], List[AudioSegment]]:
    """Collect the segment files already present in *base_dir*.

    Only ``.wav`` files named ``<prefix>_segment_<n>.wav`` are considered.
    They are ordered by ``n`` (ties broken by path).

    NOTE: the returned lists are compact. A segment's position in the list
    equals ``n - 1`` only when the directory holds a single, gap-free series
    of segments.

    Args:
        base_dir: Directory to scan.

    Returns:
        A pair ``(segment_paths, segments)``. The timestamps of the returned
        segments are relative to each file: ``start_time`` is 0.0 and
        ``end_time`` is the file's duration in seconds. Both lists are empty
        when *base_dir* is not an existing directory.
    """
    if not os.path.isdir(base_dir):
        return [], []

    indexed_paths = []
    for filename in os.listdir(base_dir):
        if not filename.endswith(".wav"):
            continue
        parts = filename.split("_")
        if len(parts) < 3 or parts[-2] != "segment":
            continue
        try:
            # File names are 1-based; convert to a 0-based index.
            segment_idx = int(parts[-1].split(".")[0]) - 1
        except ValueError:
            print(f"警告: 无法从文件名 {filename} 中提取片段索引")
            continue
        indexed_paths.append((segment_idx, os.path.join(base_dir, filename)))

    # Sorting the (index, path) tuples keeps the order deterministic even when
    # two files carry the same segment number.
    indexed_paths.sort()
    segment_paths = [path for _, path in indexed_paths]

    original_segments = []
    for path in segment_paths:
        audio_data, sample_rate = sf.read(path)
        # Frames divided by sample rate gives the duration in seconds.
        duration = len(audio_data) / sample_rate
        original_segments.append(
            AudioSegment(
                start_time=0.0,
                end_time=duration,
                audio_data=audio_data,
                is_speech=True,
            )
        )

    return segment_paths, original_segments


def _run_process(args: argparse.Namespace) -> int:
    """Handle the `process` action: split the input audio, then transcribe it."""
    if not args.input_path:
        print("❌ 使用 process 操作时必须指定输入文件路径")
        return 1

    print(f"处理音频文件: {args.input_path}")

    # Step 1: split the audio into segments.
    segment_paths, original_segments = process_audio(args.input_path)
    print(f"✅ 音频切割完成,共 {len(segment_paths)} 个片段")

    if not segment_paths:
        print("❌ 未找到有效的音频片段,无法继续处理")
        return 1

    # Step 2: transcribe the segments.
    json_path = transcribe_segments(segment_paths, original_segments)
    print(f"✅ 转录完成,结果保存在: {json_path}")
    return 0


def _run_verify(args: argparse.Namespace) -> int:
    """Handle the `verify` action: interactively review a transcription file."""
    if not args.input_path:
        print("❌ 使用 verify 操作时必须指定输入文件路径")
        return 1

    verify_transcription(args.input_path)
    return 0


def _run_aggregate(args: argparse.Namespace) -> int:
    """Handle the `aggregate` action: transcribe the chosen saved segments together."""
    if not args.segments:
        print("❌ 使用 aggregate 操作时必须指定 --segments 参数")
        return 1

    try:
        segment_indices = [int(idx.strip()) for idx in args.segments.split(",")]
    except ValueError:
        print("❌ 片段索引必须是整数,用逗号分隔")
        return 1

    # Segments are always looked up in the default segment directory.
    segment_paths, original_segments = get_existing_segments()
    if not segment_paths:
        print("❌ 未找到已处理的音频片段,请先使用 process 命令处理音频文件")
        return 1

    print(f"找到 {len(segment_paths)} 个已处理的音频片段")

    max_idx = len(segment_paths) - 1
    invalid_indices = [idx for idx in segment_indices if idx < 0 or idx > max_idx]
    if invalid_indices:
        print(f"❌ 无效的片段索引: {invalid_indices},有效范围: 0-{max_idx}")
        return 1

    json_path = transcribe_aggregated_segments(
        segment_paths, original_segments, segment_indices
    )
    print(f"✅ 聚合转录完成,结果保存在: {json_path}")
    return 0


def main() -> int:
    """Parse the command line and run the requested action.

    Returns:
        The process exit status: 0 on success, 1 when the arguments are
        unusable or the action raised an exception.
    """
    parser = argparse.ArgumentParser(description="音频处理和转录工具")
    parser.add_argument(
        "action",
        choices=["process", "verify", "aggregate"],
        help="执行的操作: process(处理音频), verify(验证转录), aggregate(聚合转录)",
    )
    parser.add_argument(
        "input_path", nargs="?", help="输入文件路径 (音频文件或JSON文件)"
    )
    parser.add_argument(
        "--segments", type=str, help="要聚合的片段索引,用逗号分隔,例如 '0,1,2'"
    )
    args = parser.parse_args()

    handlers = {
        "process": _run_process,
        "verify": _run_verify,
        "aggregate": _run_aggregate,
    }

    # Top-level boundary: report any failure and signal it via the exit status.
    try:
        return handlers[args.action](args)
    except Exception as e:
        print(f"错误: {e}")
        return 1


if __name__ == "__main__":
    sys.exit(main())