root commited on
Commit
a3f7aaa
·
1 Parent(s): ba71a6b
Files changed (1) hide show
  1. app.py +75 -0
app.py CHANGED
@@ -28,12 +28,33 @@ if "HF_TOKEN" in os.environ:
28
 
29
  # Constants
30
  GENRE_MODEL_NAME = "dima806/music_genres_classification"
 
31
  LLM_MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"
32
  SAMPLE_RATE = 22050 # Standard sample rate for audio processing
33
 
34
  # Check CUDA availability (for informational purposes)
35
  CUDA_AVAILABLE = ensure_cuda_availability()
36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  # Create genre classification pipeline
38
  print(f"Loading audio classification model: {GENRE_MODEL_NAME}")
39
  try:
@@ -209,6 +230,55 @@ Your lyrics:
209
 
210
  return lyrics
211
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
212
  def process_audio(audio_file):
213
  """Main function to process audio file, classify genre, and generate lyrics."""
214
  if audio_file is None:
@@ -218,6 +288,11 @@ def process_audio(audio_file):
218
  # Extract audio features
219
  audio_data = extract_audio_features(audio_file)
220
 
 
 
 
 
 
221
  # Classify genre
222
  top_genres = classify_genre(audio_data)
223
 
 
28
 
29
  # Constants
30
  GENRE_MODEL_NAME = "dima806/music_genres_classification"
31
+ MUSIC_DETECTION_MODEL = "MIT/ast-finetuned-audioset-10-10-0.4593"
32
  LLM_MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"
33
  SAMPLE_RATE = 22050 # Standard sample rate for audio processing
34
 
35
  # Check CUDA availability (for informational purposes)
36
  CUDA_AVAILABLE = ensure_cuda_availability()
37
 
38
+ # Create music detection pipeline
39
+ print(f"Loading music detection model: {MUSIC_DETECTION_MODEL}")
40
+ try:
41
+ music_detector = pipeline(
42
+ "audio-classification",
43
+ model=MUSIC_DETECTION_MODEL,
44
+ device=0 if CUDA_AVAILABLE else -1
45
+ )
46
+ print("Successfully loaded music detection pipeline")
47
+ except Exception as e:
48
+ print(f"Error creating music detection pipeline: {str(e)}")
49
+ # Fallback to manual loading
50
+ try:
51
+ music_processor = AutoFeatureExtractor.from_pretrained(MUSIC_DETECTION_MODEL)
52
+ music_model = AutoModelForAudioClassification.from_pretrained(MUSIC_DETECTION_MODEL)
53
+ print("Successfully loaded music detection model and feature extractor")
54
+ except Exception as e2:
55
+ print(f"Error loading music detection model components: {str(e2)}")
56
+ raise RuntimeError(f"Could not load music detection model: {str(e2)}")
57
+
58
  # Create genre classification pipeline
59
  print(f"Loading audio classification model: {GENRE_MODEL_NAME}")
60
  try:
 
230
 
231
  return lyrics
232
 
233
+ def detect_music(audio_data):
234
+ """Detect if the audio is music using the MIT AST model."""
235
+ try:
236
+ # First attempt: Try using the pipeline if available
237
+ if 'music_detector' in globals():
238
+ results = music_detector(audio_data["path"])
239
+ # Look for music-related classes in the results
240
+ music_confidence = 0.0
241
+ for result in results:
242
+ label = result["label"].lower()
243
+ if any(music_term in label for music_term in ["music", "song", "singing", "instrument"]):
244
+ music_confidence = max(music_confidence, result["score"])
245
+ return music_confidence >= 0.5
246
+
247
+ # Second attempt: Use manually loaded model components
248
+ elif 'music_processor' in globals() and 'music_model' in globals():
249
+ # Process audio input with feature extractor
250
+ inputs = music_processor(
251
+ audio_data["waveform"],
252
+ sampling_rate=audio_data["sample_rate"],
253
+ return_tensors="pt"
254
+ )
255
+
256
+ with torch.no_grad():
257
+ outputs = music_model(**inputs)
258
+ predictions = outputs.logits.softmax(dim=-1)
259
+
260
+ # Get the top predictions
261
+ values, indices = torch.topk(predictions, 5)
262
+
263
+ # Map indices to labels
264
+ labels = music_model.config.id2label
265
+
266
+ # Check for music-related classes
267
+ music_confidence = 0.0
268
+ for i, (value, index) in enumerate(zip(values[0], indices[0])):
269
+ label = labels[index.item()].lower()
270
+ if any(music_term in label for music_term in ["music", "song", "singing", "instrument"]):
271
+ music_confidence = max(music_confidence, value.item())
272
+
273
+ return music_confidence >= 0.5
274
+
275
+ else:
276
+ raise ValueError("No music detection model available")
277
+
278
+ except Exception as e:
279
+ print(f"Error in music detection: {str(e)}")
280
+ return False
281
+
282
  def process_audio(audio_file):
283
  """Main function to process audio file, classify genre, and generate lyrics."""
284
  if audio_file is None:
 
288
  # Extract audio features
289
  audio_data = extract_audio_features(audio_file)
290
 
291
+ # First check if it's music
292
+ is_music = detect_music(audio_data)
293
+ if not is_music:
294
+ return "The uploaded audio does not appear to be music. Please upload a music file.", None
295
+
296
  # Classify genre
297
  top_genres = classify_genre(audio_data)
298