Faisal-Data commited on
Commit
1352014
·
verified ·
1 Parent(s): b14bfc0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -9
app.py CHANGED
@@ -1,18 +1,20 @@
1
  import gradio as gr
2
- from transformers import CLIPProcessor, CLIPModel, pipeline
3
  import torch
4
  from PIL import Image
5
  import scipy.io.wavfile
6
 
7
  # Load the MusicGen model
8
-
 
 
9
  # Load the StreetCLIP model
10
  model = CLIPModel.from_pretrained("geolocal/StreetCLIP")
11
  processor = CLIPProcessor.from_pretrained("geolocal/StreetCLIP")
12
 
13
  labels = ['Albania', 'Andorra', 'Argentina', 'Australia', 'Austria', 'Bangladesh', 'Belgium', 'Bermuda', 'Bhutan', 'Bolivia', 'Botswana', 'Brazil', 'Bulgaria', 'Cambodia', 'Canada', 'Chile', 'China', 'Colombia', 'Croatia', 'Czech Republic', 'Denmark', 'Dominican Republic', 'Egypt', 'Ecuador', 'Estonia', 'Finland', 'France', 'Germany', 'Ghana', 'Greece', 'Greenland', 'Guam', 'Guatemala', 'Hungary', 'Iceland', 'India', 'Indonesia', 'Ireland', 'Israel', 'Italy', 'Japan', 'Jordan', 'Kenya', 'Kyrgyzstan', 'Laos', 'Latvia', 'Lesotho', 'Lithuania', 'Luxembourg', 'Macedonia', 'Madagascar', 'Malaysia', 'Malta', 'Mexico', 'Monaco', 'Mongolia', 'Montenegro', 'Netherlands', 'New Zealand', 'Nigeria', 'Norway', 'Pakistan', 'Palestine', 'Peru', 'Philippines', 'Poland', 'Portugal', 'Puerto Rico', 'Romania', 'Russia', 'Rwanda','Saudi Arabia', 'Senegal', 'Serbia', 'Singapore', 'Slovakia', 'Slovenia', 'South Africa', 'South Korea', 'Spain', 'Sri Lanka', 'Swaziland', 'Sweden', 'Switzerland', 'Syria','Taiwan', 'Thailand', 'Tunisia', 'Turkey', 'Uganda', 'Ukraine', 'United Arab Emirates', 'United Kingdom', 'United States', 'Uruguay']
14
 
15
- def process_image(image):
16
  # Ensure the image is in the correct format
17
  if isinstance(image, str):
18
  image = Image.open(image)
@@ -28,19 +30,29 @@ def process_image(image):
28
 
29
  # Get the country with the highest probability
30
  country_index = probs.argmax(dim=1).item()
31
- print(country_index)
32
  country = labels[country_index]
33
- print(country)
34
  # Generate music based on the country
 
 
 
 
 
 
 
 
 
35
 
36
  # Save the generated music to the specified path
 
 
37
 
38
  # Return the country and the path to the generated music
39
- return country
40
 
41
  # Define the Gradio interface
42
  inputs = gr.Image(type="pil", label="Upload a photo (تحميل صورة)")
43
- outputs = [gr.Textbox(label="Country (البلد)")]
44
 
45
  iface = gr.Interface(
46
  fn=process_image,
@@ -48,8 +60,8 @@ iface = gr.Interface(
48
  outputs=outputs,
49
  title="Photo to Country and Music Generator محدد الموقع من الصور بالاضافة الى انشاء م",
50
  description="Upload a photo to identify the country and generate traditional music from that country. (قم بتحميل صورة لتحديد البلد وإنشاء موسيقى تقليدية من هذا البلد.)",
51
- examples=["Egypt.jfif", "Riyadh.jpeg", "Syria.jfif", "Turkey.jfif"]
52
  )
53
 
54
  # Launch the interface
55
- iface.launch()
 
1
  import gradio as gr
2
+ from transformers import CLIPProcessor, CLIPModel, pipeline, AutoProcessor, MusicgenForConditionalGeneration
3
  import torch
4
  from PIL import Image
5
  import scipy.io.wavfile
6
 
7
  # Load the MusicGen model
8
+ #musicgen = pipeline("text-to-audio", model="facebook/musicgen-small")
9
+ musicProcessor = AutoProcessor.from_pretrained("facebook/musicgen-small")
10
+ musicgen = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
11
  # Load the StreetCLIP model
12
  model = CLIPModel.from_pretrained("geolocal/StreetCLIP")
13
  processor = CLIPProcessor.from_pretrained("geolocal/StreetCLIP")
14
 
15
  labels = ['Albania', 'Andorra', 'Argentina', 'Australia', 'Austria', 'Bangladesh', 'Belgium', 'Bermuda', 'Bhutan', 'Bolivia', 'Botswana', 'Brazil', 'Bulgaria', 'Cambodia', 'Canada', 'Chile', 'China', 'Colombia', 'Croatia', 'Czech Republic', 'Denmark', 'Dominican Republic', 'Egypt', 'Ecuador', 'Estonia', 'Finland', 'France', 'Germany', 'Ghana', 'Greece', 'Greenland', 'Guam', 'Guatemala', 'Hungary', 'Iceland', 'India', 'Indonesia', 'Ireland', 'Israel', 'Italy', 'Japan', 'Jordan', 'Kenya', 'Kyrgyzstan', 'Laos', 'Latvia', 'Lesotho', 'Lithuania', 'Luxembourg', 'Macedonia', 'Madagascar', 'Malaysia', 'Malta', 'Mexico', 'Monaco', 'Mongolia', 'Montenegro', 'Netherlands', 'New Zealand', 'Nigeria', 'Norway', 'Pakistan', 'Palestine', 'Peru', 'Philippines', 'Poland', 'Portugal', 'Puerto Rico', 'Romania', 'Russia', 'Rwanda','Saudi Arabia', 'Senegal', 'Serbia', 'Singapore', 'Slovakia', 'Slovenia', 'South Africa', 'South Korea', 'Spain', 'Sri Lanka', 'Swaziland', 'Sweden', 'Switzerland', 'Syria','Taiwan', 'Thailand', 'Tunisia', 'Turkey', 'Uganda', 'Ukraine', 'United Arab Emirates', 'United Kingdom', 'United States', 'Uruguay']
16
 
17
+ def process_image(image, audio_path="musicgen_out.wav"):
18
  # Ensure the image is in the correct format
19
  if isinstance(image, str):
20
  image = Image.open(image)
 
30
 
31
  # Get the country with the highest probability
32
  country_index = probs.argmax(dim=1).item()
 
33
  country = labels[country_index]
34
+
35
  # Generate music based on the country
36
+ music_description = f"Traditional music from {country}"
37
+ #music = musicgen(music_description, forward_params={"do_sample": True})
38
+ inputs = musicProcessor(
39
+ text=[music_description],
40
+ padding=True,
41
+ return_tensors="pt",
42
+ )
43
+ audio_values = musicgen.generate(**inputs, max_new_tokens=256)
44
+
45
 
46
  # Save the generated music to the specified path
47
+ sampling_rate = model.config.audio_encoder.sampling_rate
48
+ scipy.io.wavfile.write("musicgen_out.wav", rate=sampling_rate, data=audio_values[0, 0].numpy())
49
 
50
  # Return the country and the path to the generated music
51
+ return country, audio_path
52
 
53
  # Define the Gradio interface
54
  inputs = gr.Image(type="pil", label="Upload a photo (تحميل صورة)")
55
+ outputs = [gr.Textbox(label="Country (البلد)"), gr.Audio(label="Generated Music (الموسيقى المولدة)")]
56
 
57
  iface = gr.Interface(
58
  fn=process_image,
 
60
  outputs=outputs,
61
  title="Photo to Country and Music Generator محدد الموقع من الصور بالاضافة الى انشاء م",
62
  description="Upload a photo to identify the country and generate traditional music from that country. (قم بتحميل صورة لتحديد البلد وإنشاء موسيقى تقليدية من هذا البلد.)",
63
+ examples=["/content/Egypt.jfif", "/content/Riyadh.jpeg", "/content/Syria.jfif", "/content/Turkey.jfif"]
64
  )
65
 
66
  # Launch the interface
67
+ iface.launch(debug=True)