[email protected] committed on
Commit
c7df6b9
·
1 Parent(s): b88b8d2

fixing audio.py

Browse files
Files changed (1) hide show
  1. tasks/audio.py +118 -113
tasks/audio.py CHANGED
@@ -1,57 +1,52 @@
1
- from fastapi import APIRouter
2
  from datetime import datetime
3
- from datasets import load_dataset
4
  from sklearn.metrics import accuracy_score
5
  import os
6
  import joblib
7
- import librosa
8
  import numpy as np
9
- import logging
 
10
 
11
  from .utils.evaluation import AudioEvaluationRequest
12
  from .utils.emissions import tracker, clean_emissions_data, get_space_info
13
 
14
- from dotenv import load_dotenv
15
- load_dotenv()
16
-
17
- # Configure logging
18
- logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
19
- logger = logging.getLogger(__name__)
20
-
21
  router = APIRouter()
22
 
23
  DESCRIPTION = "Chainsaw Detection Model"
24
  ROUTE = "/audio"
25
 
26
- # Load the trained model and scaler
27
- MODEL_PATH = "models/audio_model.joblib"
28
- model_data = joblib.load(MODEL_PATH)
29
- model = model_data["model"]
30
- scaler = model_data["scaler"]
31
-
32
- def extract_features(audio_array):
33
- """Extract features from audio array."""
34
- logger.debug("Extracting features from audio array...")
 
 
35
  try:
36
- # Ensure the audio is in mono
37
- if len(audio_array.shape) > 1:
38
- audio_array = np.mean(audio_array, axis=1)
39
 
40
- # Extract MFCC features
41
  mfccs = librosa.feature.mfcc(
42
- y=audio_array,
43
- sr=12000,
44
  n_mfcc=13,
45
  n_fft=2048,
46
  hop_length=512
47
  )
48
 
49
  # Extract additional features
50
- zcr = librosa.feature.zero_crossing_rate(audio_array)
51
- rms = librosa.feature.rms(y=audio_array)
52
- spectral_centroid = librosa.feature.spectral_centroid(y=audio_array, sr=12000)
53
 
54
- # Compute statistics
55
  feature_vector = np.concatenate([
56
  np.mean(mfccs, axis=1),
57
  np.std(mfccs, axis=1),
@@ -60,95 +55,105 @@ def extract_features(audio_array):
60
  [np.mean(spectral_centroid)]
61
  ])
62
 
63
- logger.debug("Features extracted successfully.")
64
  return feature_vector
 
65
  except Exception as e:
66
- logger.error(f"Error extracting features: {e}")
67
- return None
68
 
69
  @router.post(ROUTE, tags=["Audio Task"], description=DESCRIPTION)
70
  async def evaluate_audio(request: AudioEvaluationRequest):
71
- """
72
- Evaluate audio classification for rainforest sound detection.
73
-
74
- Current Model: Chainsaw Detection Model
75
- - Uses a pre-trained RandomForestClassifier to detect chainsaw sounds.
76
- """
77
- logger.info("Starting audio evaluation...")
78
-
79
- # Get space info
80
- username, space_url = get_space_info()
81
- logger.info(f"Space info retrieved: username={username}, space_url={space_url}")
82
-
83
- # Load and prepare the dataset
84
- logger.info(f"Loading dataset '{request.dataset_name}'...")
85
  try:
86
- dataset = load_dataset(request.dataset_name, token=os.getenv("HF_TOKEN"))
87
- logger.info("Dataset loaded successfully.")
88
- except Exception as e:
89
- logger.error(f"Failed to load dataset: {e}")
90
- raise
91
-
92
- # Split dataset
93
- logger.info(f"Splitting dataset with test_size={request.test_size}, test_seed={request.test_seed}...")
94
- train_test = dataset["train"].train_test_split(test_size=request.test_size, seed=request.test_seed)
95
- test_dataset = train_test["test"]
96
- logger.info(f"Dataset split into {len(test_dataset)} test samples.")
97
-
98
- # Start tracking emissions
99
- logger.info("Starting emissions tracking...")
100
- tracker.start()
101
- tracker.start_task("inference")
102
-
103
- # Prepare lists to hold predictions and true labels
104
- predictions = []
105
- true_labels = []
106
- logger.info("Starting inference on test dataset...")
107
-
108
- # Loop through each audio sample in the test dataset
109
- for i, sample in enumerate(test_dataset):
110
- logger.debug(f"Processing sample {i + 1}...")
111
- audio_array = sample["audio"]["array"]
112
- label = sample["label"]
113
-
114
- # Extract features
115
- features = extract_features(audio_array)
116
- if features is not None:
117
- # Scale the features
118
- features_scaled = scaler.transform([features])
119
- # Make prediction
120
- prediction = model.predict(features_scaled)[0]
121
- predictions.append(prediction)
122
- true_labels.append(label)
123
- else:
124
- logger.warning(f"Skipping sample {i + 1} due to feature extraction error.")
125
- continue
126
-
127
- # Stop tracking emissions
128
- emissions_data = tracker.stop_task()
129
- logger.info("Inference completed. Stopping emissions tracking.")
130
-
131
- # Calculate accuracy
132
- accuracy = accuracy_score(true_labels, predictions)
133
- logger.info(f"Accuracy calculated: {accuracy:.4f}")
134
 
135
- # Prepare results dictionary
136
- results = {
137
- "username": username,
138
- "space_url": space_url,
139
- "submission_timestamp": datetime.now().isoformat(),
140
- "model_description": DESCRIPTION,
141
- "accuracy": float(accuracy),
142
- "energy_consumed_wh": emissions_data.energy_consumed * 1000,
143
- "emissions_gco2eq": emissions_data.emissions * 1000,
144
- "emissions_data": clean_emissions_data(emissions_data),
145
- "api_route": ROUTE,
146
- "dataset_config": {
147
- "dataset_name": request.dataset_name,
148
- "test_size": request.test_size,
149
- "test_seed": request.test_seed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
150
  }
151
- }
152
 
153
- logger.info("Audio evaluation completed successfully.")
154
- return results
 
 
 
 
1
+ from fastapi import APIRouter, HTTPException
2
  from datetime import datetime
3
+ from datasets import load_dataset, get_dataset_config_names
4
  from sklearn.metrics import accuracy_score
5
  import os
6
  import joblib
 
7
  import numpy as np
8
+ import librosa
9
+ from pathlib import Path
10
 
11
  from .utils.evaluation import AudioEvaluationRequest
12
  from .utils.emissions import tracker, clean_emissions_data, get_space_info
13
 
14
+ # Router setup
 
 
 
 
 
 
15
  router = APIRouter()
16
 
17
  DESCRIPTION = "Chainsaw Detection Model"
18
  ROUTE = "/audio"
19
 
20
+ # Model loading
21
+ MODEL_PATH = Path(__file__).parent.parent / "models" / "audio_model.joblib"
22
+ try:
23
+ model_data = joblib.load(MODEL_PATH)
24
+ model = model_data["model"]
25
+ scaler = model_data["scaler"]
26
+ except Exception as e:
27
+ raise RuntimeError(f"Failed to load model: {e}")
28
+
29
+ def extract_features(audio_array, sr=12000):
30
+ """Extract audio features using Librosa"""
31
  try:
32
+ # Convert to mono if stereo
33
+ y = np.mean(audio_array, axis=1) if len(audio_array.shape) > 1 else audio_array
 
34
 
35
+ # Extract MFCCs
36
  mfccs = librosa.feature.mfcc(
37
+ y=y,
38
+ sr=sr,
39
  n_mfcc=13,
40
  n_fft=2048,
41
  hop_length=512
42
  )
43
 
44
  # Extract additional features
45
+ zcr = librosa.feature.zero_crossing_rate(y)
46
+ rms = librosa.feature.rms(y=y)
47
+ spectral_centroid = librosa.feature.spectral_centroid(y=y, sr=sr)
48
 
49
+ # Calculate statistics
50
  feature_vector = np.concatenate([
51
  np.mean(mfccs, axis=1),
52
  np.std(mfccs, axis=1),
 
55
  [np.mean(spectral_centroid)]
56
  ])
57
 
 
58
  return feature_vector
59
+
60
  except Exception as e:
61
+ raise HTTPException(status_code=400, detail=f"Feature extraction failed: {str(e)}")
 
62
 
63
  @router.post(ROUTE, tags=["Audio Task"], description=DESCRIPTION)
64
  async def evaluate_audio(request: AudioEvaluationRequest):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
  try:
66
+ # Get Space info
67
+ username, space_url = get_space_info()
68
+
69
+ # Load dataset with proper error handling
70
+ try:
71
+ # Get available configs
72
+ configs = get_dataset_config_names(request.dataset_name)
73
+
74
+ # Set up dataset loading arguments
75
+ dataset_args = {
76
+ "path": request.dataset_name,
77
+ "token": os.getenv("HF_TOKEN"),
78
+ "trust_remote_code": True
79
+ }
80
+
81
+ # If configs exist, automatically use 'default' if it's the only one
82
+ if configs:
83
+ if len(configs) == 1 and configs[0] == 'default':
84
+ dataset_args["name"] = "default"
85
+ else:
86
+ raise HTTPException(
87
+ status_code=400,
88
+ detail=f"Config name is required for this dataset. Available configs: {configs}"
89
+ )
90
+
91
+ dataset = load_dataset(**dataset_args)
92
+
93
+ except Exception as e:
94
+ raise HTTPException(
95
+ status_code=400,
96
+ detail=f"Failed to load dataset: {str(e)}"
97
+ )
98
+
99
+ # Split dataset
100
+ split = dataset["train"].train_test_split(
101
+ test_size=request.test_size,
102
+ seed=request.test_seed
103
+ )
104
+ test_data = split["test"]
 
 
 
 
 
 
 
 
 
105
 
106
+ # Track emissions
107
+ tracker.start()
108
+ tracker.start_task("inference")
109
+
110
+ # Process features
111
+ features = []
112
+ valid_samples = []
113
+ for sample in test_data:
114
+ try:
115
+ if 'audio' in sample and isinstance(sample['audio'], dict) and 'array' in sample['audio']:
116
+ feature = extract_features(sample['audio']['array'])
117
+ if feature is not None:
118
+ features.append(feature)
119
+ valid_samples.append(sample)
120
+ except Exception as e:
121
+ print(f"Skipping sample due to error: {e}")
122
+ continue
123
+
124
+ if not features:
125
+ raise HTTPException(
126
+ status_code=400,
127
+ detail="No valid features could be extracted from the audio samples"
128
+ )
129
+
130
+ # Scale features and make predictions
131
+ scaled_features = scaler.transform(features)
132
+ predictions = model.predict(scaled_features)
133
+ true_labels = [sample["label"] for sample in valid_samples]
134
+
135
+ # Calculate results
136
+ emissions_data = tracker.stop_task()
137
+
138
+ return {
139
+ "username": username,
140
+ "space_url": space_url,
141
+ "submission_timestamp": datetime.now().isoformat(),
142
+ "model_description": DESCRIPTION,
143
+ "accuracy": float(accuracy_score(true_labels, predictions)),
144
+ "energy_consumed_wh": emissions_data.energy_consumed * 1000,
145
+ "emissions_gco2eq": emissions_data.emissions * 1000,
146
+ "emissions_data": clean_emissions_data(emissions_data),
147
+ "api_route": ROUTE,
148
+ "dataset_config": {
149
+ "dataset_name": request.dataset_name,
150
+ "test_size": request.test_size,
151
+ "test_seed": request.test_seed
152
+ }
153
  }
 
154
 
155
+ except Exception as e:
156
+ raise HTTPException(
157
+ status_code=500,
158
+ detail=f"An error occurred during audio evaluation: {str(e)}"
159
+ )