from fastapi import APIRouter, HTTPException
from datetime import datetime
from datasets import load_dataset
from sklearn.metrics import accuracy_score
import os
import pickle
from pathlib import Path
import numpy as np
import librosa
from concurrent.futures import ThreadPoolExecutor
import multiprocessing

from .utils.evaluation import AudioEvaluationRequest
from .utils.emissions import tracker, clean_emissions_data, get_space_info
from dotenv import load_dotenv
import logging

# Configuration
load_dotenv()
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

router = APIRouter()

DESCRIPTION = "Parallel Random Forest with Feature Engineering"
ROUTE = "/audio"
MODEL_PATH = Path("/app/models/audio_model.pkl")
SAMPLING_RATE = 12000
N_MFCC = 13
NUM_WORKERS = multiprocessing.cpu_count()
BATCH_SIZE = 32


def process_batch_parallel(batch):
    """Process a batch of test samples in parallel."""
    features = []
    labels = []
    try:
        logger.info(f"Batch type: {type(batch)}")
        with ThreadPoolExecutor(max_workers=NUM_WORKERS) as executor:
            futures = []
            # Iterate over the dataset slice, submitting one extraction job per sample
            for i in range(len(batch)):
                audio = batch[i]['audio']
                label = batch[i]['label']
                logger.info(f"Processing audio sample {i}")
                logger.info(f"Audio type: {type(audio)}")
                # The audio field is a dict, so test for the key rather than an attribute
                if isinstance(audio, dict) and 'array' in audio:
                    logger.info(f"Audio shape: {audio['array'].shape}")
                future = executor.submit(extract_features_parallel, audio)
                futures.append((future, label))

            # Collect results in submission order, skipping failed extractions
            for idx, (future, label) in enumerate(futures):
                try:
                    feature = future.result()
                    if feature is not None:
                        logger.info(f"Successfully extracted features for sample {idx}")
                        features.append(feature)
                        labels.append(label)
                    else:
                        logger.warning(f"No features extracted for sample {idx}")
                except Exception as e:
                    logger.error(f"Feature extraction error for sample {idx}: {str(e)}")
                    continue

        logger.info(f"Successfully processed {len(features)} samples out of {len(batch)}")
        return features, labels
    except Exception as e:
        logger.error(f"Batch processing error: {str(e)}")
        return [], []


def extract_features_parallel(audio_data):
    """Optimized parallel feature extraction."""
    try:
        if isinstance(audio_data, dict):
            if 'array' in audio_data:
                audio_array = audio_data['array']
            elif 'path' in audio_data:
                # Only a file path is available, so load the audio from disk
                y, sr = librosa.load(audio_data['path'], sr=SAMPLING_RATE)
                audio_array = y
            else:
                logger.error("No array or path in audio data")
                return None
        else:
            audio_array = audio_data

        if len(audio_array) == 0:
            logger.error("Empty audio array")
            return None

        # Convert to mono if the signal is stereo
        y = np.mean(audio_array, axis=1) if audio_array.ndim > 1 else audio_array

        # Pad clips shorter than one second at the target sampling rate
        if len(y) < SAMPLING_RATE:
            logger.warning("Audio too short, padding")
            y = np.pad(y, (0, SAMPLING_RATE - len(y)))

        # Extract the individual features concurrently
        with ThreadPoolExecutor(max_workers=4) as executor:
            futures = [
                executor.submit(librosa.feature.mfcc, y=y, sr=SAMPLING_RATE, n_mfcc=N_MFCC),
                executor.submit(librosa.feature.zero_crossing_rate, y),
                executor.submit(librosa.feature.rms, y=y),
                executor.submit(librosa.feature.spectral_centroid, y=y, sr=SAMPLING_RATE),
            ]
            mfccs, zcr, rms, spectral_centroid = [f.result() for f in futures]

        # Summarize frame-level features into a fixed-length vector
        feature_vector = np.concatenate([
            np.mean(mfccs, axis=1),
            np.std(mfccs, axis=1),
            [np.mean(zcr)],
            [np.mean(rms)],
            [np.mean(spectral_centroid)],
        ])
        return feature_vector
    except Exception as e:
        logger.error(f"Feature extraction error: {str(e)}")
        return None
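
# Feature vector layout (derived from extract_features_parallel above): N_MFCC
# means + N_MFCC standard deviations + three scalar summaries (zero-crossing
# rate, RMS, spectral centroid), i.e. 13 + 13 + 3 = 29 dimensions. A minimal
# self-check on synthetic noise (a sketch, safe to delete):
#
#     noise = np.random.randn(SAMPLING_RATE).astype(np.float32)
#     vec = extract_features_parallel({"array": noise})
#     assert vec is not None and vec.shape == (2 * N_MFCC + 3,)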
@router.post(ROUTE, tags=["Audio Task"], description=DESCRIPTION)
async def evaluate_audio(request: AudioEvaluationRequest):
    """Evaluate audio classification with parallel processing."""
    try:
        logger.info("Starting audio evaluation...")
        username, space_url = get_space_info()

        logger.info(f"Loading dataset: {request.dataset_name}")
        dataset = load_dataset(
            request.dataset_name,
            token=os.getenv("HF_TOKEN")
        )

        logger.info("Splitting dataset...")
        train_test = dataset["train"].train_test_split(
            test_size=request.test_size,
            seed=request.test_seed
        )
        test_dataset = train_test["test"]
        logger.info(f"Test dataset size: {len(test_dataset)}")

        # Track energy consumption and emissions for the inference phase only
        tracker.start()
        tracker.start_task("inference")

        logger.info("Processing test data...")
        x_test = []
        true_labels = []
        for i in range(0, len(test_dataset), BATCH_SIZE):
            logger.info(f"Processing batch starting at sample {i}/{len(test_dataset)}")
            batch = test_dataset.select(range(i, min(i + BATCH_SIZE, len(test_dataset))))
            try:
                features, labels = process_batch_parallel(batch)
                x_test.extend(features)
                true_labels.extend(labels)
            except Exception as e:
                logger.error(f"Error processing batch: {str(e)}")
                continue

        if len(x_test) == 0:
            raise ValueError("No valid features could be extracted")

        logger.info("Loading model...")
        if not MODEL_PATH.exists():
            raise FileNotFoundError(f"Model not found at {MODEL_PATH}")
        with open(MODEL_PATH, 'rb') as f:
            model_data = pickle.load(f)
        model = model_data['model']
        scaler = model_data['scaler']

        logger.info("Making predictions...")
        x_test = np.array(x_test)
        x_test_scaled = scaler.transform(x_test) if scaler is not None else x_test
        predictions = model.predict(x_test_scaled)

        emissions_data = tracker.stop_task()
        accuracy = accuracy_score(true_labels, predictions)
        logger.info(f"Evaluation complete. Accuracy: {accuracy}")

        return {
            "username": username,
            "space_url": space_url,
            "submission_timestamp": datetime.now().isoformat(),
            "model_description": DESCRIPTION,
            "accuracy": float(accuracy),
            "energy_consumed_wh": emissions_data.energy_consumed * 1000,
            "emissions_gco2eq": emissions_data.emissions * 1000,
            "emissions_data": clean_emissions_data(emissions_data),
            "api_route": ROUTE,
            "dataset_config": {
                "dataset_name": request.dataset_name,
                "test_size": request.test_size,
                "test_seed": request.test_seed,
            },
        }
    except Exception as e:
        logger.error(f"Error in evaluate_audio: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail=f"An error occurred during evaluation: {str(e)}"
        )
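
# A note on the pickled artifact (a sketch under assumptions, not part of this
# module): evaluate_audio expects MODEL_PATH to hold a dict with 'model' and
# 'scaler' keys. Assuming the training side uses scikit-learn (the
# RandomForestClassifier/StandardScaler choices below are inferred from
# DESCRIPTION, not confirmed by this file; x_train and y_train are
# placeholders), a compatible file could be produced like this:
#
#     from sklearn.ensemble import RandomForestClassifier
#     from sklearn.preprocessing import StandardScaler
#     import pickle
#
#     scaler = StandardScaler().fit(x_train)
#     model = RandomForestClassifier(n_jobs=-1)
#     model.fit(scaler.transform(x_train), y_train)
#     with open("audio_model.pkl", "wb") as f:
#         pickle.dump({"model": model, "scaler": scaler}, f)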