Spaces:
Sleeping
Sleeping
File size: 7,671 Bytes
c7df6b9 fe4a4cb b773910 fe4a4cb 3b09640 b773910 e101d8c c7df6b9 4619a60 4d6e8c2 fe4a4cb 4619a60 4d6e8c2 4619a60 b773910 4619a60 b773910 4d6e8c2 4619a60 1c33274 4619a60 b773910 4619a60 70f5f26 4619a60 e101d8c 4619a60 b773910 4619a60 b773910 4619a60 b773910 e101d8c b773910 e101d8c b773910 e101d8c 4619a60 3b09640 e101d8c 4d6e8c2 4619a60 e101d8c 4619a60 c7df6b9 4619a60 b773910 4619a60 b773910 c7df6b9 b773910 4619a60 e101d8c c7df6b9 4619a60 b773910 c7df6b9 4619a60 b773910 4619a60 b773910 4619a60 b773910 4619a60 b773910 c7df6b9 b773910 4619a60 b773910 c7df6b9 b773910 c7df6b9 b773910 4d6e8c2 e101d8c c7df6b9 4619a60 c7df6b9 b773910 c7df6b9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 |
from fastapi import APIRouter, HTTPException
from datetime import datetime
from datasets import load_dataset
from sklearn.metrics import accuracy_score
import os
import pickle
from pathlib import Path
import numpy as np
import librosa
from concurrent.futures import ThreadPoolExecutor
import multiprocessing
from .utils.evaluation import AudioEvaluationRequest
from .utils.emissions import tracker, clean_emissions_data, get_space_info
from dotenv import load_dotenv
import logging
# ---------------------------------------------------------------------------
# Module configuration
# ---------------------------------------------------------------------------
# Pull environment variables (e.g. HF_TOKEN, read during evaluation) from a
# local .env file when one exists.
load_dotenv()

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

router = APIRouter()

# Route metadata, also echoed back in every evaluation result.
DESCRIPTION = "Parallel Random Forest with Feature Engineering"
ROUTE = "/audio"

# Pickled dict holding the fitted model under 'model' and a scaler under 'scaler'.
MODEL_PATH = Path("/app") / "models" / "audio_model.pkl"

# Audio / feature-extraction parameters.
SAMPLING_RATE = 12_000  # Hz; rate used for loading audio and computing features
N_MFCC = 13             # number of MFCC coefficients per frame

# Parallelism settings.
NUM_WORKERS = multiprocessing.cpu_count()  # thread-pool size used per batch
BATCH_SIZE = 32                            # samples selected per dataset batch
def process_batch_parallel(batch):
    """Extract feature vectors for every sample in ``batch`` on a thread pool.

    Args:
        batch: Sized, integer-indexable collection of rows (e.g. a
            ``datasets.Dataset`` slice); each row must expose ``'audio'`` and
            ``'label'`` keys.

    Returns:
        ``(features, labels)``: two equally long lists. A sample whose feature
        extraction raises or returns ``None`` is skipped together with its
        label, so the lists stay aligned. Any other error aborts the whole
        batch and ``([], [])`` is returned (the error is logged).
    """
    features = []
    labels = []
    try:
        logger.debug(f"Batch type: {type(batch)}")
        pending = []
        with ThreadPoolExecutor(max_workers=NUM_WORKERS) as executor:
            for i in range(len(batch)):
                # Read the row once: on a datasets.Dataset every ``batch[i]``
                # builds (and, for audio columns, decodes) the row afresh, so
                # indexing it twice per sample doubled that work.
                row = batch[i]
                audio = row['audio']
                label = row['label']
                logger.debug(f"Processing audio sample {i}")
                logger.debug(f"Audio type: {type(audio)}")
                # Audio rows are dicts subscripted by key; a dict has no
                # ``array`` attribute, so check for the key, not hasattr().
                if isinstance(audio, dict) and 'array' in audio:
                    logger.debug(f"Audio shape: {np.shape(audio['array'])}")
                pending.append((executor.submit(extract_features_parallel, audio), label))

            for idx, (future, label) in enumerate(pending):
                try:
                    feature = future.result()
                except Exception:
                    # Best effort: drop this sample (and its label) but keep going.
                    logger.exception(f"Feature extraction error for sample {idx}")
                    continue
                if feature is None:
                    logger.warning(f"No features extracted for sample {idx}")
                    continue
                logger.debug(f"Successfully extracted features for sample {idx}")
                features.append(feature)
                labels.append(label)

        logger.info(f"Successfully processed {len(features)} samples out of {len(batch)}")
        return features, labels
    except Exception:
        logger.exception("Batch processing error")
        return [], []
def extract_features_parallel(audio_data):
    """Compute a fixed-length feature vector for one audio clip.

    Called once per sample from the thread pool in ``process_batch_parallel``.

    Args:
        audio_data: One of
            * a dict with ``'array'`` (decoded samples, optionally accompanied
              by ``'sampling_rate'``),
            * a dict with ``'path'`` to an audio file,
            * a raw array-like of samples, assumed to be at SAMPLING_RATE.

    Returns:
        A 1-D ``np.ndarray`` of length ``2 * N_MFCC + 3``: per-coefficient MFCC
        means, per-coefficient MFCC standard deviations, then the mean
        zero-crossing rate, mean RMS energy and mean spectral centroid.
        ``None`` when the input is unusable or extraction fails (logged).
    """
    try:
        source_sr = SAMPLING_RATE
        if isinstance(audio_data, dict):
            if 'array' in audio_data:
                audio_array = np.asarray(audio_data['array'])
                # Decoded samples carry their own rate; the features below are
                # computed at SAMPLING_RATE, so remember it for resampling.
                source_sr = audio_data.get('sampling_rate') or SAMPLING_RATE
            elif 'path' in audio_data:
                # librosa.load resamples to SAMPLING_RATE (mono by default).
                audio_array, _ = librosa.load(audio_data['path'], sr=SAMPLING_RATE)
            else:
                logger.error("No array or path in audio data")
                return None
        else:
            audio_array = np.asarray(audio_data)

        if audio_array.size == 0:
            logger.error("Empty audio array")
            return None

        # Down-mix to mono by averaging along axis 1.
        # NOTE(review): assumes multi-channel input is (samples, channels) — confirm.
        y = np.mean(audio_array, axis=1) if audio_array.ndim > 1 else audio_array

        # MFCCs (mel filterbank) and spectral centroid depend on the sample
        # rate, so a clip at another rate must be brought to SAMPLING_RATE
        # before being analysed with sr=SAMPLING_RATE.
        if source_sr != SAMPLING_RATE:
            y = librosa.resample(y, orig_sr=source_sr, target_sr=SAMPLING_RATE)

        # Zero-pad clips shorter than one second (SAMPLING_RATE samples).
        if len(y) < SAMPLING_RATE:
            logger.warning("Audio too short, padding")
            y = np.pad(y, (0, SAMPLING_RATE - len(y)))

        # Computed sequentially: this function already runs inside the
        # per-batch pool, and a nested pool per clip only oversubscribes threads.
        mfccs = librosa.feature.mfcc(y=y, sr=SAMPLING_RATE, n_mfcc=N_MFCC)
        zcr = librosa.feature.zero_crossing_rate(y)
        rms = librosa.feature.rms(y=y)
        spectral_centroid = librosa.feature.spectral_centroid(y=y, sr=SAMPLING_RATE)

        return np.concatenate([
            np.mean(mfccs, axis=1),
            np.std(mfccs, axis=1),
            [np.mean(zcr)],
            [np.mean(rms)],
            [np.mean(spectral_centroid)],
        ])
    except Exception:
        logger.exception("Feature extraction error")
        return None
def _extract_test_features(test_dataset):
    """Extract features for the whole test split, BATCH_SIZE samples at a time.

    Returns:
        ``(x_test, true_labels)``: equally long lists of feature vectors and
        their labels. Samples or whole batches whose extraction fails are
        logged and skipped so one bad clip does not abort the evaluation.
    """
    x_test = []
    true_labels = []
    n_samples = len(test_dataset)
    n_batches = (n_samples + BATCH_SIZE - 1) // BATCH_SIZE
    for batch_number, start in enumerate(range(0, n_samples, BATCH_SIZE), start=1):
        logger.info(f"Processing batch {batch_number}/{n_batches}")
        batch = test_dataset.select(range(start, min(start + BATCH_SIZE, n_samples)))
        try:
            features, labels = process_batch_parallel(batch)
        except Exception:
            logger.exception("Error processing batch")
            continue
        x_test.extend(features)
        true_labels.extend(labels)
    return x_test, true_labels


def _load_model():
    """Load the pickled model artifact stored at MODEL_PATH.

    Returns:
        ``(model, scaler)``; ``scaler`` is ``None`` when the artifact has none.

    Raises:
        FileNotFoundError: if MODEL_PATH does not exist.
    """
    if not MODEL_PATH.exists():
        raise FileNotFoundError(f"Model not found at {MODEL_PATH}")
    # SECURITY: pickle.load can execute arbitrary code. This is only acceptable
    # while MODEL_PATH is a fixed, trusted location — never feed it user input.
    with open(MODEL_PATH, 'rb') as f:
        model_data = pickle.load(f)
    return model_data['model'], model_data.get('scaler')


@router.post(ROUTE, tags=["Audio Task"], description=DESCRIPTION)
async def evaluate_audio(request: AudioEvaluationRequest):
    """Evaluate the audio classifier on a held-out split of ``request.dataset_name``.

    The dataset's ``"train"`` split is divided using ``request.test_size`` and
    ``request.test_seed``. Features are extracted from the test part and the
    pickled model at MODEL_PATH predicts labels. Accuracy plus the
    energy/emissions measured by ``tracker`` for the "inference" task are
    returned.

    Returns:
        A dict with submission metadata, accuracy, energy (Wh) and emissions
        (gCO2eq) figures, cleaned emissions data and the dataset configuration.

    Raises:
        HTTPException: status 500 wrapping any error raised during evaluation.
    """
    # NOTE(review): all work below is blocking and runs directly on the event
    # loop. That also serialises evaluations sharing the module-level
    # ``tracker``; add a lock before offloading this to a thread pool.
    try:
        logger.info("Starting audio evaluation...")
        username, space_url = get_space_info()

        logger.info(f"Loading dataset: {request.dataset_name}")
        dataset = load_dataset(
            request.dataset_name,
            token=os.getenv("HF_TOKEN"),
        )

        logger.info("Splitting dataset...")
        train_test = dataset["train"].train_test_split(
            test_size=request.test_size,
            seed=request.test_seed,
        )
        test_dataset = train_test["test"]
        logger.info(f"Test dataset size: {len(test_dataset)}")

        # Everything up to stop_task() is measured as the "inference" task:
        # feature extraction, model loading and prediction.
        tracker.start()
        tracker.start_task("inference")
        try:
            logger.info("Processing test data...")
            x_test, true_labels = _extract_test_features(test_dataset)
            if not x_test:
                raise ValueError("No valid features could be extracted")

            logger.info("Loading model...")
            model, scaler = _load_model()

            logger.info("Making predictions...")
            x_test = np.array(x_test)
            x_test_scaled = scaler.transform(x_test) if scaler is not None else x_test
            predictions = model.predict(x_test_scaled)
        except Exception:
            # Close the open task before propagating. If it stayed active, a
            # later request's start_task("inference") would not begin a fresh
            # measurement (codecarbon ignores start_task while a task is
            # running — confirm for the tracker in use).
            try:
                tracker.stop_task()
            except Exception:
                logger.exception("Failed to stop emissions task after error")
            raise
        emissions_data = tracker.stop_task()

        accuracy = accuracy_score(true_labels, predictions)
        logger.info(f"Evaluation complete. Accuracy: {accuracy}")

        return {
            "username": username,
            "space_url": space_url,
            "submission_timestamp": datetime.now().isoformat(),
            "model_description": DESCRIPTION,
            "accuracy": float(accuracy),
            "energy_consumed_wh": emissions_data.energy_consumed * 1000,
            "emissions_gco2eq": emissions_data.emissions * 1000,
            "emissions_data": clean_emissions_data(emissions_data),
            "api_route": ROUTE,
            "dataset_config": {
                "dataset_name": request.dataset_name,
                "test_size": request.test_size,
                "test_seed": request.test_seed,
            },
        }
    except Exception as e:
        logger.exception(f"Error in evaluate_audio: {str(e)}")
        raise HTTPException(
            status_code=500,
            detail=f"An error occurred during evaluation: {str(e)}",
        ) from e
|