File size: 3,215 Bytes
cad6415
 
 
 
 
 
 
f8cf759
cad6415
 
 
 
 
 
 
9f5c8e1
cad6415
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f8cf759
 
cad6415
 
9f5c8e1
cad6415
 
f8cf759
cad6415
 
 
f8cf759
cad6415
 
 
 
 
 
 
26cb65e
cad6415
 
 
14a7eda
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import gradio as gr
from transformers import CLIPProcessor, CLIPModel, pipeline
import torch
from PIL import Image
import scipy.io.wavfile

# MusicGen text-to-audio pipeline. It is currently disabled; uncomment it
# (and the matching code in process_image) to bring music generation back.
# musicgen = pipeline("text-to-audio", model="facebook/musicgen-small")

# Hugging Face Hub id of the StreetCLIP checkpoint, shared by the model and
# its processor so the two can never drift apart.
STREETCLIP_CHECKPOINT = "geolocal/StreetCLIP"
model = CLIPModel.from_pretrained(STREETCLIP_CHECKPOINT)
processor = CLIPProcessor.from_pretrained(STREETCLIP_CHECKPOINT)

# Candidate countries scored against each uploaded photo. Order matters:
# process_image maps the winning index straight back into this list.
labels = [
    "Albania", "Andorra", "Argentina", "Australia", "Austria",
    "Bangladesh", "Belgium", "Bermuda", "Bhutan", "Bolivia", "Botswana",
    "Brazil", "Bulgaria",
    "Cambodia", "Canada", "Chile", "China", "Colombia", "Croatia",
    "Czech Republic",
    "Denmark", "Dominican Republic",
    "Egypt", "Ecuador", "Estonia",
    "Finland", "France",
    "Germany", "Ghana", "Greece", "Greenland", "Guam", "Guatemala",
    "Hungary",
    "Iceland", "India", "Indonesia", "Ireland", "Israel", "Italy",
    "Japan", "Jordan",
    "Kenya", "Kyrgyzstan",
    "Laos", "Latvia", "Lesotho", "Lithuania", "Luxembourg",
    "Macedonia", "Madagascar", "Malaysia", "Malta", "Mexico", "Monaco",
    "Mongolia", "Montenegro",
    "Netherlands", "New Zealand", "Nigeria", "Norway",
    "Pakistan", "Palestine", "Peru", "Philippines", "Poland", "Portugal",
    "Puerto Rico",
    "Romania", "Russia", "Rwanda",
    "Saudi Arabia", "Senegal", "Serbia", "Singapore", "Slovakia",
    "Slovenia", "South Africa", "South Korea", "Spain", "Sri Lanka",
    "Swaziland", "Sweden", "Switzerland", "Syria",
    "Taiwan", "Thailand", "Tunisia", "Turkey",
    "Uganda", "Ukraine", "United Arab Emirates", "United Kingdom",
    "United States", "Uruguay",
]

def process_image(image):
    """Predict which country a photo was taken in, using StreetCLIP.

    Every entry of the module-level ``labels`` list is scored against the
    image by the CLIP model in a single batch, and the best-scoring label
    is returned.

    Args:
        image: A ``PIL.Image.Image`` (what the ``gr.Image(type="pil")``
            input delivers), or a path string to an image file.

    Returns:
        The entry of ``labels`` with the highest image-text score.

    Raises:
        gr.Error: If ``image`` is ``None`` (e.g. the user pressed Submit
            without uploading). Gradio then shows a readable message
            instead of a model traceback.
    """
    if image is None:
        raise gr.Error("Please upload a photo (يرجى تحميل صورة)")

    if isinstance(image, str):
        # Image.open is lazy and keeps the file handle open until the pixels
        # are read. convert() forces the load, so the file can be closed here.
        with Image.open(image) as opened:
            image = opened.convert("RGB")

    # Score the image against every country label in one forward pass.
    inputs = processor(
        text=labels, images=image, return_tensors="pt", padding=True
    )
    with torch.no_grad():
        logits_per_image = model(**inputs).logits_per_image

    # Softmax is monotonic, so argmax over the raw logits selects the same
    # label without computing probabilities that are never used.
    country_index = logits_per_image.argmax(dim=1).item()

    # NOTE: music generation is disabled. The MusicGen pipeline is commented
    # out at module level, and the old save step referenced an undefined
    # `audio_path`. Re-enabling it needs a musicgen pipeline, an output path
    # and an audio output component in the Gradio interface.
    return labels[country_index]

# Gradio UI: one image input and one text output showing the predicted
# country. User-facing labels are bilingual (English / Arabic).
inputs = gr.Image(type="pil", label="Upload a photo (تحميل صورة)")
outputs = [gr.Textbox(label="Country (البلد)")]

# NOTE(review): the title and description still advertise music generation,
# but process_image currently returns only the country. Update this text if
# MusicGen stays disabled.
iface = gr.Interface(
    fn=process_image,
    inputs=inputs,
    outputs=outputs,
    # Arabic part: "location finder from photos, plus music generation".
    title="Photo to Country and Music Generator محدد الموقع من الصور بالاضافة الى انشاء موسيقى",
    description="Upload a photo to identify the country and generate traditional music from that country. (قم بتحميل صورة لتحديد البلد وإنشاء موسيقى تقليدية من هذا البلد.)",
    # Relative paths: presumably these files ship next to this script.
    examples=["Egypt.jfif", "Riyadh.jpeg", "Syria.jfif", "Turkey.jfif"],
)

# Launch the web server only when run as a script, so importing this module
# does not start serving.
if __name__ == "__main__":
    iface.launch()