hriteshMaikap committed
Commit 32c9322 · verified · Parent: 6fd65bb

Update app.py

Files changed (1):
  app.py +114 -37

app.py CHANGED
@@ -4,6 +4,9 @@ import torch.nn as nn
 import torchaudio
 import json
 import numpy as np
+import matplotlib.pyplot as plt
+from io import BytesIO
+import base64
 from huggingface_hub import hf_hub_download
 from transformers import PretrainedConfig, PreTrainedModel
 
@@ -13,7 +16,7 @@ class AudioLanguageClassifierConfig(PretrainedConfig):
 
     def __init__(
         self,
-        num_labels=10,  # Changed from 12 to 10 to match the saved model
+        num_labels=10,
         sampling_rate=16000,
         num_mel_bins=128,
         feature_size=512,
@@ -158,42 +161,94 @@ def load_model():
         config_path = hf_hub_download(repo_id=repo_id, filename="config.json")
         mappings_path = hf_hub_download(repo_id=repo_id, filename="language_mappings.json")
 
-        # Load the config
+        # Load language mappings first to get correct number of labels
+        with open(mappings_path, "r") as f:
+            mappings = json.load(f)
+
+        id_to_language = {int(k): v for k, v in mappings["id_to_language"].items()}
+        num_languages = len(id_to_language)
+
+        # Load the config with correct number of labels
         with open(config_path, "r") as f:
             config_dict = json.load(f)
 
-        # IMPORTANT: Override num_labels to 10 since the model was trained with 10 classes
-        config_dict["num_labels"] = 10
-
+        config_dict["num_labels"] = num_languages
         config = AudioLanguageClassifierConfig(**config_dict)
 
-        # Load the model
+        # Create the model
         model = AudioLanguageClassifier(config)
-        model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
-        model.eval()
 
-        # Load language mappings
-        with open(mappings_path, "r") as f:
-            mappings = json.load(f)
+        # Load and adapt state dict as needed
+        state_dict = torch.load(model_path, map_location=torch.device('cpu'))
 
-        id_to_language = {int(k): v for k, v in mappings["id_to_language"].items()}
+        # Fix classifier weights and biases if needed
+        if 'classifier.weight' in state_dict and state_dict['classifier.weight'].size(0) != config.num_labels:
+            print(f"Adjusting classifier size from {state_dict['classifier.weight'].size(0)} to {config.num_labels}")
+            old_size = state_dict['classifier.weight'].size(0)
+
+            # Create new classifier layer with correct size
+            new_classifier = nn.Linear(config.feature_size, config.num_labels)
+
+            # Copy weights and biases for available classes
+            with torch.no_grad():
+                # Copy weights for the classes we have
+                new_classifier.weight.data[:old_size, :] = state_dict['classifier.weight']
+                new_classifier.bias.data[:old_size] = state_dict['classifier.bias']
+
+            # Update state dict with new weights
+            state_dict['classifier.weight'] = new_classifier.weight.data
+            state_dict['classifier.bias'] = new_classifier.bias.data
+
+        # Load the updated state dict
+        model.load_state_dict(state_dict)
+        model.eval()
 
         return model, config, id_to_language
 
     except Exception as e:
-        gr.Warning(f"Error loading model: {e}")
-        # Return placeholders with error message
-        raise gr.Error(f"Failed to load the language classification model: {e}")
+        print(f"Error loading model: {e}")
+        import traceback
+        traceback.print_exc()
+        raise gr.Error(f"Failed to load the language classification model: {str(e)}")
+
+# Function to create a bar chart visualization
+def create_confidence_chart(probs, id_to_language):
+    plt.figure(figsize=(10, 5))
+    languages = [id_to_language[i] for i in range(len(id_to_language))]
+
+    # Sort by confidence score
+    indices = np.argsort(probs)[::-1]
+    sorted_languages = [languages[i] for i in indices]
+    sorted_confidences = [probs[i] for i in indices]
+
+    # Use a colormap - highest confidence gets different color
+    colors = ['#1f77b4'] * len(sorted_languages)
+    colors[0] = '#ff7f0e'  # Highlight the top prediction
+
+    plt.bar(sorted_languages, sorted_confidences, color=colors)
+    plt.xticks(rotation=45, ha='right')
+    plt.title('Language Detection Confidence')
+    plt.xlabel('Language')
+    plt.ylabel('Confidence')
+    plt.tight_layout()
+
+    # Save plot to a bytes buffer
+    buf = BytesIO()
+    plt.savefig(buf, format='png')
+    plt.close()
+    buf.seek(0)
+
+    # Convert to base64 string for HTML embedding
+    img_str = base64.b64encode(buf.read()).decode('utf-8')
+
+    return f"<img src='data:image/png;base64,{img_str}' alt='Confidence Chart'>"
 
 # Function to process audio and make predictions
 def classify_language(audio):
+    if audio is None:
+        return {"No audio detected": 1.0}, "Please record or upload audio to analyze."
+
     try:
-        # Load model on first inference
-        global model, config, id_to_language, feature_extractor
-        if 'model' not in globals() or model is None:
-            model, config, id_to_language = load_model()
-            feature_extractor = AudioFeatureExtractor(config)
-
         # Get audio data
         sr, waveform = audio
 
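Note: the heart of this hunk is adapting a checkpoint whose `classifier` head was saved with a different number of output classes than the config now requests. The same idea in isolation, as a minimal sketch (the helper name is ours; it assumes the head is a plain `nn.Linear` stored under `classifier.weight`/`classifier.bias`, as in this model):

```python
import torch
import torch.nn as nn

def resize_classifier_head(state_dict, in_features, num_labels):
    """Pad (or trim) a Linear head in a checkpoint to num_labels rows."""
    old_size = state_dict["classifier.weight"].size(0)
    if old_size == num_labels:
        return state_dict  # sizes already agree

    # A fresh layer supplies randomly initialized rows for any new classes.
    new_head = nn.Linear(in_features, num_labels)
    keep = min(old_size, num_labels)  # min() also covers the shrinking case
    with torch.no_grad():
        new_head.weight.data[:keep, :] = state_dict["classifier.weight"][:keep, :]
        new_head.bias.data[:keep] = state_dict["classifier.bias"][:keep]

    state_dict["classifier.weight"] = new_head.weight.data
    state_dict["classifier.bias"] = new_head.bias.data
    return state_dict
```

Rows with no counterpart in the checkpoint keep their random initialization, so predictions for those classes are meaningless until the head is fine-tuned; the hunk above accepts the same trade-off.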
@@ -230,29 +285,52 @@ def classify_language(audio):
         logits = outputs["logits"]
         probs = torch.nn.functional.softmax(logits, dim=1)[0]
 
-        # Get top predictions
-        num_classes = min(3, len(id_to_language))
-        top_probs, top_ids = torch.topk(probs, num_classes)
-
-        # Format results
-        results = {}
-        for i, (prob, pred_id) in enumerate(zip(top_probs, top_ids)):
-            lang = id_to_language.get(pred_id.item(), f"Unknown-{pred_id.item()}")
-            results[lang] = float(prob)
+        # Only consider valid language indices
+        valid_indices = list(range(len(id_to_language)))
+        valid_probs = probs[valid_indices].cpu().numpy()
+
+        # Generate the confidence visualization
+        chart_html = create_confidence_chart(valid_probs, id_to_language)
+
+        # Get top 3 predictions (or all if fewer than 3)
+        num_classes = min(3, len(id_to_language))
+        top_indices = np.argsort(valid_probs)[::-1][:num_classes]
+
+        # Format results
+        results = {}
+        for idx in top_indices:
+            lang = id_to_language.get(idx, f"Unknown-{idx}")
+            results[lang] = float(valid_probs[idx])
 
-        return results
+        return results, chart_html
 
     except Exception as e:
-        return {"Error": 1.0, "Details": str(e)}
+        import traceback
+        traceback.print_exc()
+        return {"Error": 1.0}, f"<p>Error processing audio: {str(e)}</p>"
+
+# Initialize model and feature extractor
+try:
+    model, config, id_to_language = load_model()
+    feature_extractor = AudioFeatureExtractor(config)
+    languages = list(id_to_language.values())
+    print(f"Model loaded successfully. Found {len(languages)} languages: {languages}")
+except Exception as e:
+    print(f"Error initializing model: {e}")
+    model, config, id_to_language, feature_extractor = None, None, None, None
+    languages = []
 
 # Create the Gradio interface
 demo = gr.Interface(
     fn=classify_language,
-    # Changed type from "tuple" to "numpy" to fix the error
     inputs=gr.Audio(sources=["microphone", "upload"], type="numpy"),
-    outputs=gr.Label(num_top_classes=3),
+    outputs=[
+        gr.Label(num_top_classes=3),
+        gr.HTML(label="Confidence Chart")
+    ],
     title="Indian Language Identification",
-    description="Record or upload audio to identify the Indian language being spoken.",
+    description="Record or upload audio to identify the Indian language being spoken. " +
+                f"Supported languages: {', '.join(languages) if languages else 'Error loading language list'}",
     examples=[],
     article="""
     <div style="text-align: center;">
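Note: once `outputs` is a list of two components, Gradio expects the handler to return a matching 2-tuple, which is why `classify_language` now ends with `return results, chart_html`. A toy sketch of that mapping (dummy handler and values, illustrative only):

```python
import gradio as gr

def toy_handler(audio):
    label_payload = {"Marathi": 0.82, "Hindi": 0.11, "Tamil": 0.04}  # feeds gr.Label
    html_payload = "<p>(confidence chart HTML goes here)</p>"        # feeds gr.HTML
    return label_payload, html_payload

toy = gr.Interface(
    fn=toy_handler,
    inputs=gr.Audio(sources=["microphone", "upload"], type="numpy"),
    outputs=[gr.Label(num_top_classes=3), gr.HTML(label="Confidence Chart")],
)
```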
@@ -262,12 +340,11 @@ demo = gr.Interface(
         <li>Recording length of 3-5 seconds is ideal</li>
         <li>Make sure to speak a full sentence or phrase</li>
     </ul>
+    <p>Model by <a href="https://huggingface.co/hriteshMaikap/languageClassifier" target="_blank">hriteshMaikap</a></p>
     </div>
     """
 )
 
 # Launch the app
 if __name__ == "__main__":
-    # Initialize model as None to lazy-load on first inference
-    model, config, id_to_language, feature_extractor = None, None, None, None
     demo.launch()
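Note: the `create_confidence_chart` helper added in this commit renders entirely server-side. Matplotlib draws into an in-memory PNG, which is base64-encoded into a `data:` URI that `gr.HTML` displays without hosting any static files. The core pattern, stripped of app specifics (the function name and the explicit `Agg` backend are our assumptions):

```python
import base64
from io import BytesIO

import matplotlib
matplotlib.use("Agg")  # headless backend, typical for a server process
import matplotlib.pyplot as plt

def fig_to_html_img(fig):
    """Serialize a Matplotlib figure into an inline <img> tag."""
    buf = BytesIO()
    fig.savefig(buf, format="png")
    plt.close(fig)
    buf.seek(0)
    encoded = base64.b64encode(buf.read()).decode("utf-8")
    return f"<img src='data:image/png;base64,{encoded}' alt='chart'>"

fig, ax = plt.subplots()
ax.bar(["Hindi", "Marathi"], [0.7, 0.3])
html = fig_to_html_img(fig)  # starts with "<img src='data:image/png;base64,"
```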
 
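Note: with `type="numpy"`, `gr.Audio` passes the handler a `(sample_rate, waveform)` tuple, so the app can be smoke-tested without launching the UI. A hypothetical check (not part of the commit; silence will still yield some arbitrary confidence distribution):

```python
import numpy as np

# One second of silence at 16 kHz, in the (sr, waveform) tuple format
# that gr.Audio with type="numpy" delivers.
silence = (16000, np.zeros(16000, dtype=np.float32))
results, chart_html = classify_language(silence)
print(results)          # language -> confidence dict for gr.Label
print(chart_html[:60])  # inline <img> tag from create_confidence_chart
```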