root committed · a3f7aaa
Parent(s): ba71a6b
ss

app.py CHANGED
@@ -28,12 +28,33 @@ if "HF_TOKEN" in os.environ:
 
 # Constants
 GENRE_MODEL_NAME = "dima806/music_genres_classification"
+MUSIC_DETECTION_MODEL = "MIT/ast-finetuned-audioset-10-10-0.4593"
 LLM_MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"
 SAMPLE_RATE = 22050  # Standard sample rate for audio processing
 
 # Check CUDA availability (for informational purposes)
 CUDA_AVAILABLE = ensure_cuda_availability()
 
+# Create music detection pipeline
+print(f"Loading music detection model: {MUSIC_DETECTION_MODEL}")
+try:
+    music_detector = pipeline(
+        "audio-classification",
+        model=MUSIC_DETECTION_MODEL,
+        device=0 if CUDA_AVAILABLE else -1
+    )
+    print("Successfully loaded music detection pipeline")
+except Exception as e:
+    print(f"Error creating music detection pipeline: {str(e)}")
+    # Fallback to manual loading
+    try:
+        music_processor = AutoFeatureExtractor.from_pretrained(MUSIC_DETECTION_MODEL)
+        music_model = AutoModelForAudioClassification.from_pretrained(MUSIC_DETECTION_MODEL)
+        print("Successfully loaded music detection model and feature extractor")
+    except Exception as e2:
+        print(f"Error loading music detection model components: {str(e2)}")
+        raise RuntimeError(f"Could not load music detection model: {str(e2)}")
+
 # Create genre classification pipeline
 print(f"Loading audio classification model: {GENRE_MODEL_NAME}")
 try:
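For context on the block above: the same transformers pipeline call can be exercised on its own to see what the AST detector returns. The sketch below is illustrative only; the file path and the CPU device choice are placeholders, not part of this commit.

# Sketch: standalone check of the AST music detector (paths are placeholders).
from transformers import pipeline

detector = pipeline(
    "audio-classification",
    model="MIT/ast-finetuned-audioset-10-10-0.4593",
    device=-1,  # CPU; 0 would select the first CUDA GPU, as in the app code
)
results = detector("sample.wav")  # hypothetical local audio file
# The pipeline returns a list of {"label": ..., "score": ...} dicts,
# highest score first.
for r in results:
    print(f'{r["label"]}: {r["score"]:.3f}')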
@@ -209,6 +230,55 @@ Your lyrics:
 
     return lyrics
 
+def detect_music(audio_data):
+    """Detect if the audio is music using the MIT AST model."""
+    try:
+        # First attempt: Try using the pipeline if available
+        if 'music_detector' in globals():
+            results = music_detector(audio_data["path"])
+            # Look for music-related classes in the results
+            music_confidence = 0.0
+            for result in results:
+                label = result["label"].lower()
+                if any(music_term in label for music_term in ["music", "song", "singing", "instrument"]):
+                    music_confidence = max(music_confidence, result["score"])
+            return music_confidence >= 0.5
+
+        # Second attempt: Use manually loaded model components
+        elif 'music_processor' in globals() and 'music_model' in globals():
+            # Process audio input with feature extractor
+            inputs = music_processor(
+                audio_data["waveform"],
+                sampling_rate=audio_data["sample_rate"],
+                return_tensors="pt"
+            )
+
+            with torch.no_grad():
+                outputs = music_model(**inputs)
+                predictions = outputs.logits.softmax(dim=-1)
+
+            # Get the top predictions
+            values, indices = torch.topk(predictions, 5)
+
+            # Map indices to labels
+            labels = music_model.config.id2label
+
+            # Check for music-related classes
+            music_confidence = 0.0
+            for i, (value, index) in enumerate(zip(values[0], indices[0])):
+                label = labels[index.item()].lower()
+                if any(music_term in label for music_term in ["music", "song", "singing", "instrument"]):
+                    music_confidence = max(music_confidence, value.item())
+
+            return music_confidence >= 0.5
+
+        else:
+            raise ValueError("No music detection model available")
+
+    except Exception as e:
+        print(f"Error in music detection: {str(e)}")
+        return False
+
 def process_audio(audio_file):
     """Main function to process audio file, classify genre, and generate lyrics."""
     if audio_file is None:
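detect_music() above reads audio_data["path"], audio_data["waveform"], and audio_data["sample_rate"], so extract_audio_features() (not shown in this diff) presumably returns a dict of that shape. The sketch below reconstructs such a dict with librosa purely for illustration; the helper name and file path are hypothetical.

# Sketch: the dict shape detect_music() appears to expect, inferred from the
# keys it reads. extract_audio_features() is not part of this diff, so this is
# an assumption about its return value, not the app's actual implementation.
import librosa

def make_audio_data(path, sample_rate=22050):
    waveform, sr = librosa.load(path, sr=sample_rate)
    return {"path": path, "waveform": waveform, "sample_rate": sr}

audio_data = make_audio_data("sample.wav")  # hypothetical file
print(detect_music(audio_data))             # True only if a music-related label scores >= 0.5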
@@ -218,6 +288,11 @@ def process_audio(audio_file):
     # Extract audio features
     audio_data = extract_audio_features(audio_file)
 
+    # First check if it's music
+    is_music = detect_music(audio_data)
+    if not is_music:
+        return "The uploaded audio does not appear to be music. Please upload a music file.", None
+
     # Classify genre
     top_genres = classify_genre(audio_data)
 
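The gate above returns a user-facing message plus None for the second output when no music is detected. A minimal way to exercise it is sketched below; the file path is a placeholder and the success-path return shape is not shown in this diff, so the check is deliberately defensive.

# Sketch: exercising the new music gate (placeholder path; not part of the commit).
result = process_audio("speech_only.wav")
if isinstance(result, tuple) and result[1] is None:
    print(result[0])  # "The uploaded audio does not appear to be music. ..."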