from io import BytesIO
from urllib.request import urlopen
import soundfile
import torch
from datasets import load_dataset, Audio
import numpy as np
from transformers import AutoModel, AutoProcessor, BatchFeature
from tqdm import tqdm
import json
import os
import time
from datetime import datetime
from whisper_normalizer.english import EnglishTextNormalizer
from whisper_normalizer.basic import BasicTextNormalizer
import sacrebleu
from jiwer import cer, wer
from torch.utils.data import Dataset, DataLoader
import soundfile as sf
import re

# Text normalizer applied to references/predictions before scoring, keyed by language code.
normalizer = {
    "en_us": EnglishTextNormalizer(),
    "ko_kr": BasicTextNormalizer(),
}

# Load model and processor
model_id = "junnei/gemma-3-4b-it-speech"
revision = "main"  # "v1.0"
model = AutoModel.from_pretrained(
    model_id,
    device_map="auto",
    revision=revision,
    trust_remote_code=True,
).eval()

processor = AutoProcessor.from_pretrained(
    model_id,
    revision=revision,
    trust_remote_code=True,
)

# Create a per-run, timestamped directory for result files
results_dir = f"evaluation_results_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
os.makedirs(results_dir, exist_ok=True)

# Prompt templates; the "ast" template is formatted with the target language name.
INSTRUCTION = {
    "ast": "Translate the audio to {0}.",
    "asr": "Transcribe the audio clip into text.",
}


class BaseAudioDataset(Dataset):
    """Base class for the speech datasets below.

    Provides dataset-filtering helpers and turns (audio, instruction, answer)
    into processor inputs via ``prepare_model_inputs``.
    """

    def __init__(self, processor, split, sampling_rate=16000, debug=False):
        self.processor = processor
        # True whenever the split name contains "train" (e.g. "train", "train.360").
        self.training = "train" in split
        self.debug = debug
        self.sampling_rate = sampling_rate
        self.name = ""

    def set_dataset_name(self, name):
        self.name = name

    @staticmethod
    def filter_corrupted_files(data, audio_field, text_fields, dataset_name,
                               sampling_rate=16000, debug=True, num_proc=16):
        """Drop examples with unreadable audio or empty text, then enable audio decoding.

        Args:
            data: Hugging Face ``datasets`` dataset.
            audio_field: name of the ``Audio`` column.
            text_fields: columns that must be non-empty once double quotes are removed.
            dataset_name: label used only in the debug output.
            sampling_rate: sampling rate passed to the re-enabled ``Audio`` decoder.
            debug: print before/after sizes.
            num_proc: worker processes for ``data.filter``.

        Returns:
            The filtered dataset with ``audio_field`` cast to a decoding ``Audio`` feature.
        """
        original_size = len(data)

        # Disable decoding so each example carries only its file path; decoding is
        # switched back on after unreadable files have been removed.
        data = data.cast_column(audio_field, Audio(decode=False))

        def identify_corrupted_files(example):
            # Keep the example only if soundfile can read the file and every text
            # field is non-empty after stripping double quotes. Any exception
            # (unreadable file, missing path, ...) drops the example.
            try:
                sf.read(example[audio_field]["path"])
                for field in text_fields:
                    if example[field].replace('"', '') == "":
                        return False
                return True
            except Exception:
                return False

        data = data.filter(identify_corrupted_files, num_proc=num_proc)
        validated_size = len(data)

        # Re-enable audio decoding at the requested sampling rate
        data = data.cast_column(audio_field, Audio(sampling_rate=sampling_rate, decode=True))

        if debug:
            print(f"데이터셋: {dataset_name}")
            print(f"원본 데이터 개수: {original_size}")
            print(f"필터링 후 데이터 개수: {validated_size}")
            # Guard: an empty input dataset would otherwise raise ZeroDivisionError.
            if original_size:
                print(f"필터링 비율: {validated_size/original_size:.2%}")

        return data

    @staticmethod
    def filter_by_audio_length(data, audio_field, min_sec=2, max_sec=20, debug=True, num_proc=16):
        """Keep examples whose audio duration is within ``[min_sec, max_sec]`` seconds.

        Duration is the number of frames divided by the example's
        ``sampling_rate``. Examples whose duration cannot be computed are dropped.

        Args:
            data: Hugging Face ``datasets`` dataset with a decoded ``Audio`` column.
            audio_field: name of the ``Audio`` column.
            min_sec, max_sec: inclusive duration bounds in seconds.
            debug: print before/after sizes and per-example errors.
            num_proc: worker processes for ``data.filter``.
        """
        original_size = len(data)

        def filter_audio_by_length(example):
            try:
                audio = np.asarray(example[audio_field]['array'])
                if audio.ndim > 1:
                    # Multi-channel audio may be (channels, frames) or (frames, channels);
                    # the frame axis is the longer one. ``ndim`` is not the channel
                    # count, so it must not be used as a divisor.
                    num_frames = max(audio.shape)
                else:
                    num_frames = len(audio)
                audio_length = num_frames / example[audio_field]['sampling_rate']
                return min_sec <= audio_length <= max_sec
            except Exception as e:
                if debug:
                    print(f"오류 발생: {str(e)[:100]}... - 샘플 제외됨")
                return False

        data = data.filter(filter_audio_by_length, num_proc=num_proc)
        filtered_size = len(data)

        if debug:
            print(f"길이 필터링 전 데이터 개수: {original_size}")
            print(f"길이 필터링 후 데이터 개수: {filtered_size}")
            # Guard: an empty input dataset would otherwise raise ZeroDivisionError.
            if original_size:
                print(f"필터링 비율: {filtered_size/original_size:.2%}")

        return data

    def prepare_model_inputs(self, audio_array, instruction, answer_text):
        """Build model inputs for one example.

        Renders ``instruction`` as a single user turn with the tokenizer's chat
        template, then runs the processor on that prompt together with
        ``audio_array``.

        Returns:
            dict with ``input_ids``, ``token_type_ids``, ``input_audio_embeds``,
            ``audio_embed_sizes`` and ``input_modes`` taken from the processor
            output, plus the raw reference text under ``answer``.
        """
        user_message = {
            'role': 'user',
            # NOTE(review): the empty-string prefix looks like a placeholder (perhaps for
            # an audio marker); the audio itself is passed to the processor below.
            'content': '' + instruction,
        }
        prompt = self.processor.tokenizer.apply_chat_template(
            [user_message], tokenize=False, add_generation_prompt=True, add_bos=True
        )
        # add_special_tokens=False: the template already requested BOS (add_bos=True).
        inputs = self.processor(
            text=prompt, audio=[audio_array], add_special_tokens=False, return_tensors='pt'
        )
        return {
            'input_ids': inputs.input_ids,
            'token_type_ids': inputs.token_type_ids,
            'input_audio_embeds': inputs.input_audio_embeds,
            'audio_embed_sizes': inputs.audio_embed_sizes,
            'input_modes': inputs.input_modes,
            'answer': answer_text,
        }


# CoVoST2 Dataset Class
class CoVoSTDataset(BaseAudioDataset):
    """CoVoST2 (``junnei/covost2``) for ASR (``ast=False``) or speech translation (``ast=True``)."""

    def __init__(self, processor, data_dir, split, ast=False, lang=("en_ko", "Korean"),
                 sampling_rate=16000, debug=False):
        super().__init__(processor, split, sampling_rate, debug)
        self.set_dataset_name("CoVoST")
        self.ast = ast
        # lang = (covost2 config name, target language name used in the AST prompt)
        self.lang = lang[0]

        self.data = load_dataset(
            "junnei/covost2",
            lang[0],
            data_dir=data_dir,
            split=split,
            trust_remote_code=True,
        )

        text_fields = ["sentence", "translation"] if ast else ["sentence"]
        # Pass the instance's sampling rate so a non-default value is honored.
        self.data = self.filter_corrupted_files(
            self.data, "audio", text_fields, "CoVoST", sampling_rate=self.sampling_rate
        )

        # (Optional) Audio length filtering
        self.data = self.filter_by_audio_length(self.data, "audio")

        # Instruction setting
        self.instruction = INSTRUCTION["ast"].format(lang[1]) if ast else INSTRUCTION["asr"]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        data = self.data[idx]
        if self.ast:
            answer_text = data["translation"]
        else:
            answer_text = data["sentence"].replace('"', '')
        return self.prepare_model_inputs(data["audio"]["array"], self.instruction, answer_text)


# LibriSpeech Dataset Class
class LibriSpeechDataset(BaseAudioDataset):
    """LibriSpeech (``fixie-ai/librispeech_asr``) for ASR only; ``subset`` is e.g. "clean" or "other"."""

    def __init__(self, processor, subset, split, sampling_rate=16000, debug=False):
        super().__init__(processor, split, sampling_rate, debug)
        self.set_dataset_name(f"LibriSpeech_{subset}")
        # ASR only
        self.ast = False
        self.lang = "en"

        # A plain "train" request is mapped to the "train.360" split.
        if split == "train":
            split = "train.360"

        self.data = load_dataset(
            "fixie-ai/librispeech_asr",
            subset,
            split=split,
            trust_remote_code=True,
        )

        # (Optional) Audio length filtering
        self.data = self.filter_by_audio_length(self.data, "audio")

        # Instruction setting
        self.instruction = INSTRUCTION["asr"]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        data = self.data[idx]
        # LibriSpeech is only used for ASR
        answer_text = data["text"].replace('"', '')
        return self.prepare_model_inputs(data["audio"]["array"], self.instruction, answer_text)


# FLEURS Dataset Class
class FleursDataset(BaseAudioDataset):
    """FLEURS (``google/fleurs``) for ASR or AST.

    In AST mode, source and target configs are joined on their ``id`` column.
    Each joined row is the source example with the target ``transcription``
    added under ``translation``.
    """

    def __init__(self, processor, split, source_lang, target_lang=None, mode="asr",
                 sampling_rate=16000, debug=False):
        super().__init__(processor, split, sampling_rate, debug)
        self.set_dataset_name("Fleurs")

        # Mode setting (ASR or AST)
        if mode not in ["asr", "ast"]:
            raise ValueError("mode must be 'asr' or 'ast'.")
        self.mode = mode
        self.ast = (mode == "ast")
        # Validate before any download/filtering work is done.
        if self.ast and target_lang is None:
            raise ValueError("AST mode requires target_lang.")

        self.source_lang = source_lang

        # Language name mapping (expand if needed)
        self.lang_names = {
            'en_us': 'English',
            'ko_kr': 'Korean',
        }

        # Load the source-language dataset
        self.data = load_dataset(
            "google/fleurs",
            source_lang,
            split=split,
            trust_remote_code=True,
        )

        # (Optional) Audio length filtering
        self.data = self.filter_by_audio_length(self.data, "audio")

        if self.ast:
            self.target_lang = target_lang
            self.lang = f"{source_lang}_{target_lang}"

            # Load the target-language dataset (provides the reference translations)
            target_data = load_dataset(
                "google/fleurs",
                target_lang,
                split=split,
                trust_remote_code=True,
            )

            # Index both sides by id. If an id occurs more than once, the last row wins.
            source_dict = {item['id']: item for item in self.data}
            target_dict = {item['id']: item for item in target_data}

            # Keep only ids present on both sides and attach the translation field
            common_ids = set(source_dict) & set(target_dict)
            print(f"FLEURS AST Common data filtering: {len(self.data)} -> {len(common_ids)}")
            self.data = [
                {**source_dict[sample_id], 'translation': target_dict[sample_id]['transcription']}
                for sample_id in common_ids
            ]

            # Instruction setting - use the target language's display name
            target_lang_name = self.lang_names.get(target_lang, target_lang.capitalize())
            self.instruction = INSTRUCTION["ast"].format(target_lang_name)
        else:
            # ASR mode
            self.lang = source_lang
            self.instruction = INSTRUCTION["asr"]

        if self.debug:
            print(f"FLEURS dataset loaded: {self.mode.upper()} mode")
            print(f"source lang: {source_lang} ({self.lang_names.get(source_lang, source_lang)})")
            if self.ast:
                print(f"target lang: {target_lang} ({self.lang_names.get(target_lang, target_lang)})")
            print(f"dataset size: {len(self.data)}")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        data = self.data[idx]
        audio_array = data["audio"]["array"]
        if self.ast:
            answer_text = data["translation"]
        else:
            answer_text = data["transcription"]
        return self.prepare_model_inputs(audio_array, self.instruction, answer_text)
def pad_sequence(sequences, padding_side='left', padding_value=0):
    """Pad a list of tensors along their first dimension to a common length.

    Args:
        sequences: non-empty list of tensors shaped ``[seq_len, *]``; the trailing
            dims are taken from ``sequences[0]`` and must match across tensors.
        padding_side: ``'left'`` or ``'right'``.
        padding_value: fill value for padded positions.

    Returns:
        Tensor of shape ``[len(sequences), max_seq_len, *]`` with the dtype and
        device of ``sequences[0]``.

    Raises:
        ValueError: if ``padding_side`` is not ``'left'`` or ``'right'``.
    """
    if padding_side not in ('right', 'left'):
        raise ValueError(f"padding_side must be 'right' or 'left', got {padding_side!r}")
    trailing_dims = sequences[0].size()[1:]
    max_len = max(len(seq) for seq in sequences)
    batch_size = len(sequences)
    output = sequences[0].new_full((batch_size, max_len) + trailing_dims, padding_value)
    for i, seq in enumerate(sequences):
        length = seq.size(0)
        if length == 0:
            # Nothing to copy; also avoids ``output[i, -0:]`` selecting the whole row.
            continue
        if padding_side == 'right':
            output[i, :length] = seq
        else:
            output[i, -length:] = seq
    return output


def cat_with_pad(tensors, dim, padding_value=0):
    """Concatenate ``tensors`` along ``dim``, padding every other dim to its maximum size.

    Raises:
        ValueError: if the tensors do not all have the same number of dimensions.
    """
    ndim = tensors[0].dim()
    if any(t.dim() != ndim for t in tensors[1:]):
        raise ValueError('All tensors must have the same number of dimensions')

    out_size = [max(t.shape[i] for t in tensors) for i in range(ndim)]
    out_size[dim] = sum(t.shape[dim] for t in tensors)
    output = tensors[0].new_full(out_size, padding_value)

    offset = 0
    for t in tensors:
        # Every dimension except ``dim`` covers the tensor's own extent from 0;
        # ``dim`` covers the next free range of the concatenation axis.
        slices = [slice(0, t.shape[d]) for d in range(ndim)]
        slices[dim] = slice(offset, offset + t.shape[dim])
        # Index with a tuple; a list of slices relies on legacy sequence-as-tuple indexing.
        output[tuple(slices)] = t
        offset += t.shape[dim]
    return output


def covost_collate_fn(batch):
    """Collate ``prepare_model_inputs`` dicts into a single ``BatchFeature``.

    ``input_ids`` are left-padded with 0 and ``attention_mask`` marks the
    non-zero ids. Audio embeddings are concatenated along dim 0 with padding on
    the other dims. ``audio_attention_mask`` is right-padded with False and is
    ``None`` when the batch holds a single sample. Reference texts are returned
    as a list under ``answer``.
    """
    input_ids_list = []
    input_audio_embeds_list = []
    audio_embed_sizes_list = []
    audio_attention_mask_list = []
    input_modes_list = []
    answer_list = []
    for inputs in batch:
        # Each sample's input_ids carries a leading batch dim of 1 (hence [0]).
        input_ids_list.append(inputs['input_ids'][0])
        input_audio_embeds_list.append(inputs['input_audio_embeds'])
        audio_embed_sizes_list.append(inputs['audio_embed_sizes'])
        # One True per position along dim 1 of this sample's audio embeddings.
        audio_attention_mask_list.append(
            inputs['input_audio_embeds'].new_full(
                (inputs['input_audio_embeds'].size(1),), True, dtype=torch.bool
            )
        )
        input_modes_list.append(inputs['input_modes'])
        answer_list.append(inputs['answer'])

    try:
        input_ids = pad_sequence(input_ids_list, padding_side='left', padding_value=0)
        audio_attention_mask = (
            pad_sequence(audio_attention_mask_list, padding_side='right', padding_value=False)
            if len(audio_attention_mask_list) > 1
            else None
        )
    except Exception as e:
        # Dump the per-sample pieces; ``audio_attention_mask`` may be unassigned here,
        # so printing it would raise UnboundLocalError and hide the real error.
        print(e)
        print(input_ids_list)
        print(audio_attention_mask_list)
        raise

    attention_mask = (input_ids != 0).long()
    input_audio_embeds = cat_with_pad(input_audio_embeds_list, dim=0)
    audio_embed_sizes = torch.cat(audio_embed_sizes_list)
    input_modes = torch.cat(input_modes_list)

    return BatchFeature(
        {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'input_audio_embeds': input_audio_embeds,
            'audio_embed_sizes': audio_embed_sizes,
            'audio_attention_mask': audio_attention_mask,
            'input_modes': input_modes,
            'answer': answer_list,
        }
    )


def save_results(results, dataset_name, task, source_lang, target_lang=None, sample_idx=None):
    """Write ``results`` as JSON into ``results_dir`` and return the file path.

    The filename is ``{task}_{dataset_name}_{source_lang}[_to_{target_lang}][_sample_{sample_idx}].json``.
    ``results`` is modified in place: a ``timestamp`` key is added.
    """
    filename = f"{task}_{dataset_name}_{source_lang}"
    if target_lang:
        filename += f"_to_{target_lang}"
    if sample_idx is not None:
        filename += f"_sample_{sample_idx}"

    filepath = os.path.join(results_dir, f"{filename}.json")

    # Add a timestamp to the results
    results["timestamp"] = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    with open(filepath, 'w', encoding='utf-8') as f:
        json.dump(results, f, ensure_ascii=False, indent=2)

    print(f"결과가 {filepath}에 저장되었습니다.")
    return filepath


def _compute_sample_metrics(reference, prediction, text_normalizer, sample_id):
    """Score one utterance; return ``{"bleu", "cer", "wer"}`` (CER/WER in percent).

    Both texts are normalized first. CER is computed with all whitespace removed.
    If scoring raises (for example, jiwer rejects an empty normalized
    reference), the error is printed and worst-case values are returned, with
    the message under ``"error"``.
    """
    try:
        ref = text_normalizer(reference)
        pred = text_normalizer(prediction)
        utt_bleu = sacrebleu.sentence_bleu(pred, [ref]).score
        utt_cer = round(cer(re.sub(r"\s+", "", ref), re.sub(r"\s+", "", pred)) * 100, 2)
        utt_wer = round(wer(ref, pred) * 100, 2)
        return {"bleu": utt_bleu, "cer": utt_cer, "wer": utt_wer}
    except Exception as e:
        print(f"Error evaluating sample {sample_id}: {e}")
        return {"bleu": 0, "cer": 100, "wer": 100, "error": str(e)}


def evaluate_task(dataset, source_lang, target_lang, num_samples=-1, batch_size = 32, is_asr=True):
    """Generate outputs for ``dataset`` and score them with BLEU / CER / WER.

    Handles both ASR (``is_asr=True``, text normalized with
    ``normalizer[source_lang]``) and speech translation (``is_asr=False``,
    normalized with ``normalizer[target_lang]``). Every 10 batches the results
    so far are saved as a checkpoint. The final results are saved and returned.

    Args:
        dataset: one of the ``BaseAudioDataset`` subclasses (must expose ``name``).
        num_samples: if positive and smaller than the dataset, evaluate a random
            subset of that size; otherwise evaluate everything.
        batch_size: DataLoader batch size.

    Returns:
        dict with dataset/task/language info, averaged ``metrics`` (``None`` when
        nothing was evaluated) and per-sample ``sample_results``.
    """
    from torch.utils.data import Subset

    task_type = "asr" if is_asr else "translation"
    eval_lang = source_lang if is_asr else target_lang
    eval_normalizer = normalizer[eval_lang]
    # Capture the name up front: a torch Subset does not carry ``name``.
    dataset_name = dataset.name
    sample_results = []

    # Optional random subset. The wrapper datasets have no ``.select``, so use Subset;
    # ``.tolist()`` yields plain Python ints for the underlying indexers.
    if 0 < num_samples < len(dataset):
        indices = np.random.choice(len(dataset), num_samples, replace=False).tolist()
        dataset = Subset(dataset, indices)

    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, collate_fn=covost_collate_fn)

    # sample id -> metrics, so checkpoints and the final pass never re-score a sample
    evaluated_samples = {}

    def metrics_for(item):
        sample_id = item["id"]
        if sample_id not in evaluated_samples:
            evaluated_samples[sample_id] = _compute_sample_metrics(
                item["reference"], item["prediction"], eval_normalizer, sample_id
            )
        return evaluated_samples[sample_id]

    for batch_idx, batch in enumerate(tqdm(dataloader)):
        batch_references = batch.pop("answer")

        if torch.cuda.is_available():
            # Only move tensors: ``audio_attention_mask`` is None for one-sample batches.
            batch = {k: (v.to("cuda") if torch.is_tensor(v) else v) for k, v in batch.items()}

        with torch.inference_mode():
            generate_ids = model.generate(
                **batch,
                max_new_tokens=256,
                # temperature=1.0,
                top_p=0.95,
                top_k=64,
                do_sample=True,
            )
        # Drop the (left-padded) prompt columns; this assumes generate() returns
        # prompt + continuation, as the slicing implies.
        input_lengths = batch['input_ids'].shape[1]
        generate_ids = generate_ids[:, input_lengths:]

        batch_predictions = processor.batch_decode(
            generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )

        for i, (reference, prediction) in enumerate(zip(batch_references, batch_predictions)):
            sample_results.append({
                "id": batch_idx * batch_size + i,
                "reference": reference,
                "prediction": prediction,
            })

        # Checkpoint every 10 batches
        if (batch_idx + 1) % 10 == 0:
            temp_results = [{**item, **metrics_for(item)} for item in sample_results]
            partial_results = {
                "task": task_type,
                "source_lang": source_lang,
                "target_lang": target_lang,
                "num_samples": len(temp_results),
                "sample_results": temp_results,
            }
            save_results(partial_results, dataset_name, task_type, source_lang, target_lang)

    # Final scoring uses the same cached, error-tolerant path as the checkpoints.
    for item in sample_results:
        item.update(metrics_for(item))

    count = len(sample_results)

    def average(key):
        return sum(item[key] for item in sample_results) / count if count else None

    results = {
        "dataset": dataset_name,
        "task": task_type,
        "source_lang": source_lang,
        "target_lang": target_lang,
        "num_samples": count,
        "metrics": {
            "bleu": average("bleu"),
            "cer": average("cer"),
            "wer": average("wer"),
        },
        "sample_results": sample_results,
    }

    # Save the final results
    save_results(results, dataset_name, task_type, source_lang, target_lang)
    return results


def _print_metric_summary(results):
    """Print the averaged BLEU / WER / CER from an ``evaluate_task`` result dict."""
    metrics = results.get('metrics', {})
    print(f"BLEU: {metrics.get('bleu', 'N/A')}")
    print(f"WER: {metrics.get('wer', 'N/A')}")
    print(f"CER: {metrics.get('cer', 'N/A')}")


# Main entry point
if __name__ == "__main__":
    # Source languages to evaluate (code, name)
    source_languages = [
        # ("ko_kr", "Korean"),
        ("en_us", "English"),  # English (US)
    ]

    # Translation targets (code, name), paired with source_languages by position
    target_languages = [
        # ("en_us", "English"),
        ("ko_kr", "Korean"),
    ]

    # CommonVoice audio directory used by CoVoST, keyed by source-language code
    data_dir = {
        # "ko_kr": "/workspace/CommonVoice/ko",
        "en_us": "/workspace/CommonVoice/EN",
    }

    # Number of samples (-1 uses the whole dataset)
    num_samples = -1
    batch_size = 32

    for source_lang, target_lang in zip(source_languages, target_languages):
        print(f"\n===== {source_lang[0]} ASR 평가 시작 =====")

        split = "test"
        # Look up the CoVoST audio directory; default to the English path used before.
        covost_data_dir = data_dir.get(source_lang[0], "/workspace/CommonVoice/EN")

        eval_datasets = [
            # CoVoST ASR mode (English -> English text)
            CoVoSTDataset(
                processor=processor,
                data_dir=covost_data_dir,
                split=split,
                ast=False,
                lang=("en_ko", "Korean"),
            ),
            # LibriSpeech clean, ASR mode (English -> English text)
            LibriSpeechDataset(processor=processor, subset="clean", split=split),
            # LibriSpeech other, ASR mode (English -> English text)
            LibriSpeechDataset(processor=processor, subset="other", split=split),
            # FLEURS ASR mode (English -> English text)
            FleursDataset(
                processor=processor,
                split=split,
                source_lang="en_us",  # English
                mode="asr",
            ),
        ]

        for dataset in eval_datasets:
            asr_results = evaluate_task(
                dataset, source_lang[0], target_lang[0], num_samples,
                batch_size=batch_size, is_asr=True,
            )
            print(f"\n=== {asr_results.get('dataset', 'Dataset')} | {source_lang[0]} ASR 결과 ===")
            _print_metric_summary(asr_results)

        try:
            print(f"\n===== {source_lang[0]} -> {target_lang[0]} 번역 평가 시작 =====")

            eval_datasets = [
                # CoVoST AST mode (English -> Korean text)
                CoVoSTDataset(
                    processor=processor,
                    data_dir=covost_data_dir,
                    split=split,
                    ast=True,
                    lang=("en_ko", "Korean"),
                ),
                # FLEURS AST mode (English -> Korean text)
                FleursDataset(
                    processor=processor,
                    split=split,
                    source_lang="en_us",  # English
                    target_lang="ko_kr",  # Korean
                    mode="ast",
                ),
            ]

            for dataset in eval_datasets:
                translation_results = evaluate_task(
                    dataset, source_lang[0], target_lang[0], num_samples,
                    batch_size=batch_size, is_asr=False,
                )
                print(f"\n=== {translation_results.get('dataset', 'Dataset')} | {source_lang[0]} -> {target_lang[0]} 번역 결과 ===")
                _print_metric_summary(translation_results)

        except Exception as e:
            # Record the failure for this language pair and move on to the next one.
            error_info = {
                "error": str(e),
                "source_lang": source_lang[0],
                "target_lang": target_lang[0],
                "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            }
            error_file = os.path.join(results_dir, f"error_translation_{source_lang[0]}_to_{target_lang[0]}_global.json")
            with open(error_file, 'w', encoding='utf-8') as f:
                json.dump(error_info, f, indent=2)
            print(f"{source_lang[0]} -> {target_lang[0]} 번역 평가 중 오류 발생: {str(e)}")

    print(f"\n모든 평가가 완료되었습니다. 결과는 {results_dir} 디렉토리에 저장되었습니다.")