Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -1,18 +1,20 @@
|
|
1 |
import gradio as gr
|
2 |
-
from transformers import CLIPProcessor, CLIPModel, pipeline
|
3 |
import torch
|
4 |
from PIL import Image
|
5 |
import scipy.io.wavfile
|
6 |
|
7 |
# Load the MusicGen model
|
8 |
-
|
|
|
|
|
9 |
# Load the StreetCLIP model
|
10 |
model = CLIPModel.from_pretrained("geolocal/StreetCLIP")
|
11 |
processor = CLIPProcessor.from_pretrained("geolocal/StreetCLIP")
|
12 |
|
13 |
labels = ['Albania', 'Andorra', 'Argentina', 'Australia', 'Austria', 'Bangladesh', 'Belgium', 'Bermuda', 'Bhutan', 'Bolivia', 'Botswana', 'Brazil', 'Bulgaria', 'Cambodia', 'Canada', 'Chile', 'China', 'Colombia', 'Croatia', 'Czech Republic', 'Denmark', 'Dominican Republic', 'Egypt', 'Ecuador', 'Estonia', 'Finland', 'France', 'Germany', 'Ghana', 'Greece', 'Greenland', 'Guam', 'Guatemala', 'Hungary', 'Iceland', 'India', 'Indonesia', 'Ireland', 'Israel', 'Italy', 'Japan', 'Jordan', 'Kenya', 'Kyrgyzstan', 'Laos', 'Latvia', 'Lesotho', 'Lithuania', 'Luxembourg', 'Macedonia', 'Madagascar', 'Malaysia', 'Malta', 'Mexico', 'Monaco', 'Mongolia', 'Montenegro', 'Netherlands', 'New Zealand', 'Nigeria', 'Norway', 'Pakistan', 'Palestine', 'Peru', 'Philippines', 'Poland', 'Portugal', 'Puerto Rico', 'Romania', 'Russia', 'Rwanda','Saudi Arabia', 'Senegal', 'Serbia', 'Singapore', 'Slovakia', 'Slovenia', 'South Africa', 'South Korea', 'Spain', 'Sri Lanka', 'Swaziland', 'Sweden', 'Switzerland', 'Syria','Taiwan', 'Thailand', 'Tunisia', 'Turkey', 'Uganda', 'Ukraine', 'United Arab Emirates', 'United Kingdom', 'United States', 'Uruguay']
|
14 |
|
15 |
-
def process_image(image):
|
16 |
# Ensure the image is in the correct format
|
17 |
if isinstance(image, str):
|
18 |
image = Image.open(image)
|
@@ -28,19 +30,29 @@ def process_image(image):
|
|
28 |
|
29 |
# Get the country with the highest probability
|
30 |
country_index = probs.argmax(dim=1).item()
|
31 |
-
print(country_index)
|
32 |
country = labels[country_index]
|
33 |
-
|
34 |
# Generate music based on the country
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
35 |
|
36 |
# Save the generated music to the specified path
|
|
|
|
|
37 |
|
38 |
# Return the country and the path to the generated music
|
39 |
-
return country
|
40 |
|
41 |
# Define the Gradio interface
|
42 |
inputs = gr.Image(type="pil", label="Upload a photo (تحميل صورة)")
|
43 |
-
outputs = [gr.Textbox(label="Country (البلد)")]
|
44 |
|
45 |
iface = gr.Interface(
|
46 |
fn=process_image,
|
@@ -48,8 +60,8 @@ iface = gr.Interface(
|
|
48 |
outputs=outputs,
|
49 |
title="Photo to Country and Music Generator محدد الموقع من الصور بالاضافة الى انشاء م",
|
50 |
description="Upload a photo to identify the country and generate traditional music from that country. (قم بتحميل صورة لتحديد البلد وإنشاء موسيقى تقليدية من هذا البلد.)",
|
51 |
-
examples=["Egypt.jfif", "Riyadh.jpeg", "Syria.jfif", "Turkey.jfif"]
|
52 |
)
|
53 |
|
54 |
# Launch the interface
|
55 |
-
iface.launch()
|
|
|
1 |
import gradio as gr
|
2 |
+
from transformers import CLIPProcessor, CLIPModel, pipeline, AutoProcessor, MusicgenForConditionalGeneration
|
3 |
import torch
|
4 |
from PIL import Image
|
5 |
import scipy.io.wavfile
|
6 |
|
7 |
# Load the MusicGen model
|
8 |
+
#musicgen = pipeline("text-to-audio", model="facebook/musicgen-small")
|
9 |
+
musicProcessor = AutoProcessor.from_pretrained("facebook/musicgen-small")
|
10 |
+
musicgen = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
|
11 |
# Load the StreetCLIP model
|
12 |
model = CLIPModel.from_pretrained("geolocal/StreetCLIP")
|
13 |
processor = CLIPProcessor.from_pretrained("geolocal/StreetCLIP")
|
14 |
|
15 |
labels = ['Albania', 'Andorra', 'Argentina', 'Australia', 'Austria', 'Bangladesh', 'Belgium', 'Bermuda', 'Bhutan', 'Bolivia', 'Botswana', 'Brazil', 'Bulgaria', 'Cambodia', 'Canada', 'Chile', 'China', 'Colombia', 'Croatia', 'Czech Republic', 'Denmark', 'Dominican Republic', 'Egypt', 'Ecuador', 'Estonia', 'Finland', 'France', 'Germany', 'Ghana', 'Greece', 'Greenland', 'Guam', 'Guatemala', 'Hungary', 'Iceland', 'India', 'Indonesia', 'Ireland', 'Israel', 'Italy', 'Japan', 'Jordan', 'Kenya', 'Kyrgyzstan', 'Laos', 'Latvia', 'Lesotho', 'Lithuania', 'Luxembourg', 'Macedonia', 'Madagascar', 'Malaysia', 'Malta', 'Mexico', 'Monaco', 'Mongolia', 'Montenegro', 'Netherlands', 'New Zealand', 'Nigeria', 'Norway', 'Pakistan', 'Palestine', 'Peru', 'Philippines', 'Poland', 'Portugal', 'Puerto Rico', 'Romania', 'Russia', 'Rwanda','Saudi Arabia', 'Senegal', 'Serbia', 'Singapore', 'Slovakia', 'Slovenia', 'South Africa', 'South Korea', 'Spain', 'Sri Lanka', 'Swaziland', 'Sweden', 'Switzerland', 'Syria','Taiwan', 'Thailand', 'Tunisia', 'Turkey', 'Uganda', 'Ukraine', 'United Arab Emirates', 'United Kingdom', 'United States', 'Uruguay']
|
16 |
|
17 |
+
def process_image(image, audio_path="musicgen_out.wav"):
|
18 |
# Ensure the image is in the correct format
|
19 |
if isinstance(image, str):
|
20 |
image = Image.open(image)
|
|
|
30 |
|
31 |
# Get the country with the highest probability
|
32 |
country_index = probs.argmax(dim=1).item()
|
|
|
33 |
country = labels[country_index]
|
34 |
+
|
35 |
# Generate music based on the country
|
36 |
+
music_description = f"Traditional music from {country}"
|
37 |
+
#music = musicgen(music_description, forward_params={"do_sample": True})
|
38 |
+
inputs = musicProcessor(
|
39 |
+
text=[music_description],
|
40 |
+
padding=True,
|
41 |
+
return_tensors="pt",
|
42 |
+
)
|
43 |
+
audio_values = musicgen.generate(**inputs, max_new_tokens=256)
|
44 |
+
|
45 |
|
46 |
# Save the generated music to the specified path
|
47 |
+
# Read the rate from the MusicGen checkpoint: `model` is the StreetCLIP CLIPModel, whose config has no `audio_encoder`.
sampling_rate = musicgen.config.audio_encoder.sampling_rate
|
48 |
+
# Write to the caller-supplied `audio_path` so the file on disk is the same one that is returned to Gradio.
scipy.io.wavfile.write(audio_path, rate=sampling_rate, data=audio_values[0, 0].numpy())
|
49 |
|
50 |
# Return the country and the path to the generated music
|
51 |
+
return country, audio_path
|
52 |
|
53 |
# Define the Gradio interface
|
54 |
inputs = gr.Image(type="pil", label="Upload a photo (تحميل صورة)")
|
55 |
+
outputs = [gr.Textbox(label="Country (البلد)"), gr.Audio(label="Generated Music (الموسيقى المولدة)")]
|
56 |
|
57 |
iface = gr.Interface(
|
58 |
fn=process_image,
|
|
|
60 |
outputs=outputs,
|
61 |
title="Photo to Country and Music Generator محدد الموقع من الصور بالاضافة الى انشاء م",
|
62 |
description="Upload a photo to identify the country and generate traditional music from that country. (قم بتحميل صورة لتحديد البلد وإنشاء موسيقى تقليدية من هذا البلد.)",
|
63 |
+
# Example images are referenced relative to app.py (as in the previous revision); "/content/..." is a Colab-only location.
examples=["Egypt.jfif", "Riyadh.jpeg", "Syria.jfif", "Turkey.jfif"]
|
64 |
)
|
65 |
|
66 |
# Launch the interface
|
67 |
+
iface.launch(debug=True)
|