from fastapi import APIRouter, HTTPException
from datetime import datetime
from datasets import load_dataset
from sklearn.metrics import accuracy_score
import os
import pickle
from pathlib import Path
import numpy as np
import librosa
from concurrent.futures import ThreadPoolExecutor
import multiprocessing
from .utils.evaluation import AudioEvaluationRequest
from .utils.emissions import tracker, clean_emissions_data, get_space_info
from dotenv import load_dotenv
import logging

# Configuration
load_dotenv()
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

router = APIRouter()

DESCRIPTION = "Parallel Random Forest with Feature Engineering"
ROUTE = "/audio"
MODEL_PATH = Path("/app/models/audio_model.pkl")
SAMPLING_RATE = 12000
N_MFCC = 13
NUM_WORKERS = multiprocessing.cpu_count()
BATCH_SIZE = 32
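
# Each sample is reduced to a 29-dimensional feature vector in
# extract_features_parallel: 13 MFCC means + 13 MFCC standard deviations
# + mean zero-crossing rate + mean RMS energy + mean spectral centroid.
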
def process_batch_parallel(batch):
    """Process a batch of test samples in parallel."""
    features = []
    labels = []
    try:
        logger.info(f"Batch type: {type(batch)}")
        with ThreadPoolExecutor(max_workers=NUM_WORKERS) as executor:
            futures = []
            # Iterate over the batch and submit one feature-extraction task per sample
            for i in range(len(batch)):
                audio = batch[i]['audio']
                label = batch[i]['label']
                logger.info(f"Processing audio sample {i}")
                logger.info(f"Audio type: {type(audio)}")
                if isinstance(audio, dict) and 'array' in audio:
                    logger.info(f"Audio shape: {audio['array'].shape}")
                future = executor.submit(extract_features_parallel, audio)
                futures.append((future, label))
            # Collect results, skipping samples whose extraction failed
            for idx, (future, label) in enumerate(futures):
                try:
                    feature = future.result()
                    if feature is not None:
                        logger.info(f"Successfully extracted features for sample {idx}")
                        features.append(feature)
                        labels.append(label)
                    else:
                        logger.warning(f"No features extracted for sample {idx}")
                except Exception as e:
                    logger.error(f"Feature extraction error for sample {idx}: {str(e)}")
                    continue
        logger.info(f"Successfully processed {len(features)} samples out of {len(batch)}")
        return features, labels
    except Exception as e:
        logger.error(f"Batch processing error: {str(e)}")
        return [], []

def extract_features_parallel(audio_data):
    """Optimized parallel feature extraction."""
    try:
        if isinstance(audio_data, dict):
            if 'array' in audio_data:
                audio_array = audio_data['array']
            elif 'path' in audio_data:
                # If we only have a file path, load the audio from disk
                y, sr = librosa.load(audio_data['path'], sr=SAMPLING_RATE)
                audio_array = y
            else:
                logger.error("No array or path in audio data")
                return None
        else:
            audio_array = audio_data

        if len(audio_array) == 0:
            logger.error("Empty audio array")
            return None

        # Convert to mono if the signal is stereo
        y = np.mean(audio_array, axis=1) if audio_array.ndim > 1 else audio_array

        # Check the minimum length and pad signals shorter than one second
        if len(y) < SAMPLING_RATE:
            logger.warning("Audio too short, padding")
            y = np.pad(y, (0, SAMPLING_RATE - len(y)))

        # Extract the features concurrently
        with ThreadPoolExecutor(max_workers=4) as executor:
            futures = [
                executor.submit(librosa.feature.mfcc, y=y, sr=SAMPLING_RATE, n_mfcc=N_MFCC),
                executor.submit(librosa.feature.zero_crossing_rate, y),
                executor.submit(librosa.feature.rms, y=y),
                executor.submit(librosa.feature.spectral_centroid, y=y, sr=SAMPLING_RATE)
            ]
            mfccs, zcr, rms, spectral_centroid = [f.result() for f in futures]

        feature_vector = np.concatenate([
            np.mean(mfccs, axis=1),
            np.std(mfccs, axis=1),
            [np.mean(zcr)],
            [np.mean(rms)],
            [np.mean(spectral_centroid)]
        ])
        return feature_vector
    except Exception as e:
        logger.error(f"Feature extraction error: {str(e)}")
        return None

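
# Note: the evaluation endpoint below unpickles MODEL_PATH and expects a dict
# with 'model' and 'scaler' keys. The training script is not part of this
# module; a minimal, hypothetical sketch that would produce a compatible
# artifact (class choices and hyperparameters are assumptions, not the
# author's actual pipeline) could look like:
#
#     from sklearn.ensemble import RandomForestClassifier
#     from sklearn.preprocessing import StandardScaler
#
#     scaler = StandardScaler().fit(x_train)  # x_train built with extract_features_parallel
#     model = RandomForestClassifier(n_jobs=-1).fit(scaler.transform(x_train), y_train)
#     with open(MODEL_PATH, "wb") as f:
#         pickle.dump({"model": model, "scaler": scaler}, f)
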
# Register the endpoint on the shared router so the main application can expose it
@router.post(ROUTE)
async def evaluate_audio(request: AudioEvaluationRequest):
    """Evaluate audio classification with parallel processing."""
    try:
        logger.info("Starting audio evaluation...")
        username, space_url = get_space_info()

        logger.info(f"Loading dataset: {request.dataset_name}")
        dataset = load_dataset(
            request.dataset_name,
            token=os.getenv("HF_TOKEN")
        )

        logger.info("Splitting dataset...")
        train_test = dataset["train"].train_test_split(
            test_size=request.test_size,
            seed=request.test_seed
        )
        test_dataset = train_test["test"]
        logger.info(f"Test dataset size: {len(test_dataset)}")

        # Track emissions for the inference phase only
        tracker.start()
        tracker.start_task("inference")

        logger.info("Processing test data...")
        x_test = []
        true_labels = []
        for i in range(0, len(test_dataset), BATCH_SIZE):
            logger.info(f"Processing batch {i}/{len(test_dataset)}")
            batch = test_dataset.select(range(i, min(i + BATCH_SIZE, len(test_dataset))))
            try:
                features, labels = process_batch_parallel(batch)
                x_test.extend(features)
                true_labels.extend(labels)
            except Exception as e:
                logger.error(f"Error processing batch: {str(e)}")
                continue

        if len(x_test) == 0:
            raise ValueError("No valid features could be extracted")

        logger.info("Loading model...")
        if not MODEL_PATH.exists():
            raise FileNotFoundError(f"Model not found at {MODEL_PATH}")
        with open(MODEL_PATH, 'rb') as f:
            model_data = pickle.load(f)
            model = model_data['model']
            scaler = model_data['scaler']

        logger.info("Making predictions...")
        x_test = np.array(x_test)
        x_test_scaled = scaler.transform(x_test) if scaler is not None else x_test
        predictions = model.predict(x_test_scaled)

        emissions_data = tracker.stop_task()

        accuracy = accuracy_score(true_labels, predictions)
        logger.info(f"Evaluation complete. Accuracy: {accuracy}")

        return {
            "username": username,
            "space_url": space_url,
            "submission_timestamp": datetime.now().isoformat(),
            "model_description": DESCRIPTION,
            "accuracy": float(accuracy),
            # Energy is reported in kWh and emissions in kgCO2eq; convert to Wh / gCO2eq
            "energy_consumed_wh": emissions_data.energy_consumed * 1000,
            "emissions_gco2eq": emissions_data.emissions * 1000,
            "emissions_data": clean_emissions_data(emissions_data),
            "api_route": ROUTE,
            "dataset_config": {
                "dataset_name": request.dataset_name,
                "test_size": request.test_size,
                "test_seed": request.test_seed,
            },
        }
    except Exception as e:
        logger.error(f"Error in evaluate_audio: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail=f"An error occurred during evaluation: {str(e)}"
        )
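
# Usage note: this module only defines the APIRouter; it becomes reachable once
# it is included in the main FastAPI application. A minimal sketch (the module
# layout and import path are assumptions, not part of this file):
#
#     from fastapi import FastAPI
#     from .audio import router as audio_router
#
#     app = FastAPI()
#     app.include_router(audio_router)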