version corrigés avec parallel processing
4619a60
raw
history blame contribute delete
7.67 kB
from fastapi import APIRouter, HTTPException
from datetime import datetime
from datasets import load_dataset
from sklearn.metrics import accuracy_score
import os
import pickle
from pathlib import Path
import numpy as np
import librosa
from concurrent.futures import ThreadPoolExecutor
import multiprocessing
from .utils.evaluation import AudioEvaluationRequest
from .utils.emissions import tracker, clean_emissions_data, get_space_info
from dotenv import load_dotenv
import logging
# --- Configuration -----------------------------------------------------------
# Load variables such as HF_TOKEN (read in evaluate_audio) from a .env file.
load_dotenv()

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

router = APIRouter()

# Endpoint metadata; both values are echoed back in every evaluation result.
ROUTE: str = "/audio"
DESCRIPTION: str = "Parallel Random Forest with Feature Engineering"

# Pickled dict holding the trained model under 'model' and its scaler under 'scaler'.
MODEL_PATH: Path = Path("/app/models/audio_model.pkl")

# Feature-extraction settings: target sample rate (Hz) and number of MFCCs.
SAMPLING_RATE: int = 12000
N_MFCC: int = 13

# Parallelism: one feature-extraction thread per CPU; the test split is fed
# to it in slices of BATCH_SIZE rows.
NUM_WORKERS: int = multiprocessing.cpu_count()
BATCH_SIZE: int = 32
def process_batch_parallel(batch):
    """Extract features for every row of ``batch`` on a thread pool.

    Each row must provide ``'audio'`` and ``'label'`` entries. The audio value
    is handed to ``extract_features_parallel`` on a pool of ``NUM_WORKERS``
    threads. Rows whose extraction returns ``None`` or raises are skipped, so
    the two returned lists stay aligned with each other, though not
    necessarily with ``batch``.

    Args:
        batch: A sized, integer-indexable collection of rows (``evaluate_audio``
            passes the result of ``Dataset.select``).

    Returns:
        ``(features, labels)``: parallel lists with one feature vector and its
        label per successfully processed row, in row order. Both lists are
        empty if the batch as a whole fails, e.g. a row lacks an ``'audio'``
        or ``'label'`` entry.
    """
    features = []
    labels = []
    try:
        logger.info(f"Batch type: {type(batch)}")
        futures = []
        with ThreadPoolExecutor(max_workers=NUM_WORKERS) as executor:
            for i in range(len(batch)):
                # Fetch the row once. On a Hugging Face Dataset each batch[i]
                # materialises the whole row, decoding its audio. The previous
                # batch[i]['audio'] / batch[i]['label'] pair did that twice.
                row = batch[i]
                audio = row['audio']
                label = row['label']
                # Per-sample chatter goes to DEBUG so INFO logs stay readable.
                logger.debug(f"Processing audio sample {i}")
                logger.debug(f"Audio type: {type(audio)}")
                # Audio rows are dicts, so test for the key. The former
                # hasattr(audio, 'array') never matched a dict.
                if isinstance(audio, dict) and audio.get('array') is not None:
                    logger.debug(f"Audio shape: {np.shape(audio['array'])}")
                futures.append((executor.submit(extract_features_parallel, audio), label))

        # Leaving the `with` waited for every task; collect in submission order.
        for idx, (future, label) in enumerate(futures):
            try:
                feature = future.result()
            except Exception as e:
                logger.exception(f"Feature extraction error for sample {idx}: {str(e)}")
                continue
            if feature is None:
                logger.warning(f"No features extracted for sample {idx}")
                continue
            logger.debug(f"Successfully extracted features for sample {idx}")
            features.append(feature)
            labels.append(label)

        logger.info(f"Successfully processed {len(features)} samples out of {len(batch)}")
        return features, labels
    except Exception as e:
        # Best-effort: a broken batch is dropped so the caller can carry on.
        logger.exception(f"Batch processing error: {str(e)}")
        return [], []
def extract_features_parallel(audio_data):
    """Compute a fixed-length feature vector for one audio clip.

    Args:
        audio_data: One of:

            * a dict with a non-``None`` ``'array'`` entry, optionally with a
              ``'sampling_rate'``. When that rate differs from
              ``SAMPLING_RATE``, the samples are resampled to it, matching
              the ``'path'`` branch.
            * a dict with a ``'path'`` entry, loaded with librosa at
              ``SAMPLING_RATE``.
            * a bare sample array or list, assumed to already be at
              ``SAMPLING_RATE``.

            Two-dimensional input is averaged over axis 1 to get mono, i.e.
            it is treated as ``(samples, channels)``.

    Returns:
        A 1-D array of ``2 * N_MFCC + 3`` values, in this order:

        * the per-coefficient MFCC means;
        * the per-coefficient MFCC standard deviations;
        * the mean zero-crossing rate;
        * the mean RMS energy;
        * the mean spectral centroid.

        Returns ``None`` if the input is unusable or any step raises; the
        error is logged.
    """
    try:
        source_sr = SAMPLING_RATE
        if isinstance(audio_data, dict):
            if audio_data.get('array') is not None:
                audio_array = audio_data['array']
                # Honour the clip's own rate; previously it was silently
                # treated as SAMPLING_RATE even when it was not.
                source_sr = audio_data.get('sampling_rate') or SAMPLING_RATE
            elif 'path' in audio_data:
                # librosa.load already resamples to SAMPLING_RATE.
                audio_array, _ = librosa.load(audio_data['path'], sr=SAMPLING_RATE)
            else:
                logger.error("No array or path in audio data")
                return None
        else:
            audio_array = audio_data

        # Accept plain lists too; .ndim is used below.
        audio_array = np.asarray(audio_array)
        if audio_array.size == 0:
            logger.error("Empty audio array")
            return None

        # Stereo -> mono, treating a 2-D array as (samples, channels).
        y = np.mean(audio_array, axis=1) if audio_array.ndim > 1 else audio_array

        if source_sr != SAMPLING_RATE:
            y = librosa.resample(y, orig_sr=source_sr, target_sr=SAMPLING_RATE)

        # Guarantee at least SAMPLING_RATE samples (one second) of signal.
        if len(y) < SAMPLING_RATE:
            logger.warning("Audio too short, padding")
            y = np.pad(y, (0, SAMPLING_RATE - len(y)))

        # Computed sequentially: process_batch_parallel already runs this
        # function on a NUM_WORKERS-thread pool. A nested 4-thread pool per
        # clip only added thread start-up cost and oversubscription.
        mfccs = librosa.feature.mfcc(y=y, sr=SAMPLING_RATE, n_mfcc=N_MFCC)
        zcr = librosa.feature.zero_crossing_rate(y)
        rms = librosa.feature.rms(y=y)
        spectral_centroid = librosa.feature.spectral_centroid(y=y, sr=SAMPLING_RATE)

        return np.concatenate([
            np.mean(mfccs, axis=1),
            np.std(mfccs, axis=1),
            [np.mean(zcr)],
            [np.mean(rms)],
            [np.mean(spectral_centroid)],
        ])
    except Exception as e:
        logger.error(f"Feature extraction error: {str(e)}")
        return None
@router.post(ROUTE, tags=["Audio Task"], description=DESCRIPTION)
async def evaluate_audio(request: AudioEvaluationRequest):
    """Score the pickled audio model on a held-out split of ``request.dataset_name``.

    The evaluation proceeds in four steps:

    1. The dataset's ``"train"`` split is divided using ``request.test_size``
       and ``request.test_seed``.
    2. Features for the test part are extracted batch by batch.
    3. The model stored at ``MODEL_PATH`` predicts labels. The file is a dict
       with ``'model'`` and, optionally, ``'scaler'``.
    4. Accuracy is reported together with the energy/emissions that
       ``tracker`` recorded for the ``"inference"`` task.

    Returns:
        dict: submission metadata, ``accuracy``, energy in Wh, emissions in
        gCO2eq, the cleaned emissions record and the dataset configuration.

    Raises:
        HTTPException: status 500 wrapping any error raised during evaluation.
    """
    # NOTE(review): everything below is blocking (dataset download, feature
    # extraction, unpickling). Inside an ``async def`` that stalls the event
    # loop for the whole evaluation. A plain ``def`` endpoint would be run
    # in FastAPI's worker threadpool instead.
    task_running = False
    try:
        logger.info("Starting audio evaluation...")
        username, space_url = get_space_info()

        # Fail fast: without a model, processing the whole test split is wasted.
        if not MODEL_PATH.exists():
            raise FileNotFoundError(f"Model not found at {MODEL_PATH}")

        logger.info(f"Loading dataset: {request.dataset_name}")
        dataset = load_dataset(
            request.dataset_name,
            token=os.getenv("HF_TOKEN")
        )

        logger.info("Splitting dataset...")
        train_test = dataset["train"].train_test_split(
            test_size=request.test_size,
            seed=request.test_seed
        )
        test_dataset = train_test["test"]
        n_samples = len(test_dataset)
        logger.info(f"Test dataset size: {n_samples}")

        tracker.start()
        tracker.start_task("inference")
        task_running = True

        logger.info("Processing test data...")
        x_test = []
        true_labels = []
        for start in range(0, n_samples, BATCH_SIZE):
            logger.info(f"Processing batch {start}/{n_samples}")
            batch = test_dataset.select(range(start, min(start + BATCH_SIZE, n_samples)))
            try:
                features, labels = process_batch_parallel(batch)
            except Exception as e:
                logger.error(f"Error processing batch: {str(e)}")
                continue
            x_test.extend(features)
            true_labels.extend(labels)

        if not x_test:
            raise ValueError("No valid features could be extracted")

        logger.info("Loading model...")
        # SECURITY: pickle runs arbitrary code on load. MODEL_PATH must only
        # ever contain a trusted, locally built artefact.
        with open(MODEL_PATH, 'rb') as f:
            model_data = pickle.load(f)
        model = model_data['model']
        # A missing scaler is treated like an explicit None (no scaling).
        scaler = model_data.get('scaler')

        logger.info("Making predictions...")
        x_test = np.array(x_test)
        x_test_scaled = scaler.transform(x_test) if scaler is not None else x_test
        predictions = model.predict(x_test_scaled)

        emissions_data = tracker.stop_task()
        task_running = False

        accuracy = accuracy_score(true_labels, predictions)
        logger.info(f"Evaluation complete. Accuracy: {accuracy}")

        return {
            "username": username,
            "space_url": space_url,
            "submission_timestamp": datetime.now().isoformat(),
            "model_description": DESCRIPTION,
            "accuracy": float(accuracy),
            "energy_consumed_wh": emissions_data.energy_consumed * 1000,
            "emissions_gco2eq": emissions_data.emissions * 1000,
            "emissions_data": clean_emissions_data(emissions_data),
            "api_route": ROUTE,
            "dataset_config": {
                "dataset_name": request.dataset_name,
                "test_size": request.test_size,
                "test_seed": request.test_seed,
            },
        }
    except Exception as e:
        if task_running:
            # Close the "inference" task opened above so it is not left
            # running after a failure. Otherwise it may presumably be picked
            # up by, and skew, the next request's measurement. Confirm
            # against the tracker's start_task semantics.
            try:
                tracker.stop_task()
            except Exception:
                logger.exception("Failed to stop emissions tracking task")
        logger.exception(f"Error in evaluate_audio: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail=f"An error occurred during evaluation: {str(e)}"
        ) from e