sander-wood committed
Commit f72d35e · verified · 1 parent: b5c9b89

Update app.py

Files changed (1)
  1. app.py +351 -351
app.py CHANGED
@@ -1,351 +1,351 @@
- import os
- import torch
- import numpy as np
- import gradio as gr
- import zipfile
- import json
- import requests
- import subprocess
- import shutil
- from transformers import BlipProcessor, BlipForConditionalGeneration
-
- title = "# 🗜️ CLaMP 3 - Multimodal & Multilingual Semantic Music Search"
-
- badges = """
- <div style="text-align: center;">
- <a href="https://sanderwood.github.io/clamp3/">
- <img src="https://img.shields.io/badge/CLaMP%203%20Homepage-GitHub-181717?style=for-the-badge&logo=home-assistant" alt="Homepage">
- </a>
- <a href="https://arxiv.org/abs/2502.10362">
- <img src="https://img.shields.io/badge/CLaMP%203%20Paper-Arxiv-red?style=for-the-badge&logo=arxiv" alt="Paper">
- </a>
- <a href="https://github.com/sanderwood/clamp3">
- <img src="https://img.shields.io/badge/CLaMP%203%20Code-GitHub-181717?style=for-the-badge&logo=github" alt="GitHub">
- </a>
- <a href="https://huggingface.co/spaces/sander-wood/clamp3">
- <img src="https://img.shields.io/badge/CLaMP%203%20Demo-Gradio-green?style=for-the-badge&logo=gradio" alt="Demo">
- </a>
- <a href="https://huggingface.co/sander-wood/clamp3/tree/main">
- <img src="https://img.shields.io/badge/Model%20Weights-Hugging%20Face-ffcc00?style=for-the-badge&logo=huggingface" alt="Model Weights">
- </a>
- <a href="https://huggingface.co/datasets/sander-wood/m4-rag">
- <img src="https://img.shields.io/badge/M4--RAG%20Dataset-Hugging%20Face-ffcc00?style=for-the-badge&logo=huggingface" alt="Dataset">
- </a>
- <a href="https://huggingface.co/datasets/sander-wood/wikimt-x">
- <img src="https://img.shields.io/badge/WikiMT--X%20Benchmark-Hugging%20Face-ffcc00?style=for-the-badge&logo=huggingface" alt="Benchmark">
- </a>
- </div>
-
- <style>
- div a {
-     display: inline-block;
-     margin: 5px;
- }
- div a img {
-     height: 30px;
- }
- </style>
- """
-
- description = """CLaMP 3 is a **multimodal and multilingual** music information retrieval (MIR) framework, supporting **sheet music, audio, and performance signals** in **100 languages**. Using **contrastive learning**, it aligns these modalities in a shared space for **cross-modal retrieval**.
-
- ### 🔍 **How This Demo Works**
- - You can **retrieve music using any text input (in any language) or an image** (`.png`, `.jpg`).
- - When using an image, **BLIP** generates a caption, which is then used for retrieval.
- - Since CLaMP 3's training data includes **rich visual descriptions of musical scenes**, it can **match images to semantically relevant music**.
-
- ### ⚠️ **Limitations**
- - This demo retrieves music **only from the WikiMT-X benchmark (1,000 pieces)**.
- - These pieces are **mainly from the U.S. and Western Europe (especially the U.S.)** and **mostly from the 20th century**.
- - Thus, retrieval results are **mostly limited to Western 20th-century music**, so you **won’t** find music from **other regions or historical periods**.
-
- 🔧 **Need retrieval for a different music collection?** Deploy **[CLaMP 3](https://github.com/sanderwood/clamp3)** on your own dataset.
- Generally, the larger and more diverse the reference music dataset, the better the retrieval quality, increasing the likelihood of finding relevant and accurately matched music.
-
- **Note: This project is for research use only.**
- """
-
- # Load BLIP image captioning model and processor
- processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
- blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
-
- # Download weight file if it does not exist
- weights_url = "https://huggingface.co/sander-wood/clamp3/resolve/main/weights_clamp3_saas_h_size_768_t_model_FacebookAI_xlm-roberta-base_t_length_128_a_size_768_a_layers_12_a_length_128_s_size_768_s_layers_12_p_size_64_p_length_512.pth"
- weights_filename = "weights_clamp3_saas_h_size_768_t_model_FacebookAI_xlm-roberta-base_t_length_128_a_size_768_a_layers_12_a_length_128_s_size_768_s_layers_12_p_size_64_p_length_512.pth"
-
- if not os.path.exists(weights_filename):
-     print("Downloading weights file...")
-     response = requests.get(weights_url, stream=True)
-     response.raise_for_status()
-     with open(weights_filename, "wb") as f:
-         for chunk in response.iter_content(chunk_size=8192):
-             if chunk:
-                 f.write(chunk)
-     print("Weights file downloaded.")
-
- ZIP_PATH = "features.zip"
- if os.path.exists(ZIP_PATH):
-     print(f"Extracting {ZIP_PATH}...")
-     with zipfile.ZipFile(ZIP_PATH, "r") as zip_ref:
-         zip_ref.extractall(".")
-     print("Extraction complete.")
-
- # Load metadata
- metadata_map = {}
- METADATA_FILE = "wikimt-x-public.jsonl"
- if os.path.exists(METADATA_FILE):
-     with open(METADATA_FILE, "r", encoding="utf-8") as f:
-         for line in f:
-             data = json.loads(line)
-             metadata_map[data["id"]] = data
- else:
-     print(f"Warning: {METADATA_FILE} not found.")
-
- features_cache = {}
-
- def get_info(folder_path):
-     """
-     Load all .npy files from the specified folder and return a dictionary
-     with the file names (without extension) as keys.
-     """
-     if folder_path in features_cache:
-         return features_cache[folder_path]
-     if not os.path.exists(folder_path):
-         return {}
-     files = sorted(os.listdir(folder_path))
-     features = {}
-     for file in files:
-         if file.endswith(".npy"):
-             key = file.split(".")[0]
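-             # Each stored .npy has a leading batch axis; [0] keeps the single global feature vector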
-             try:
-                 features[key] = np.load(os.path.join(folder_path, file))[0]
-             except Exception as e:
-                 print(f"Error loading {file}: {e}")
-     features_cache[folder_path] = features
-     return features
-
- def find_top_similar(query_file, reference_folder):
-     """
-     Compare the query feature with all reference features in the specified folder
-     using cosine similarity and return the top 10 candidate results in the format:
-     Title | Artists | sim: SimilarityScore.
-     """
-     top_k = 10
-     try:
-         query_feature = np.load(query_file.name)[0]
-     except Exception as e:
-         return [], f"Error loading query feature: {e}"
-     query_tensor = torch.tensor(query_feature, dtype=torch.float32).unsqueeze(dim=0)
-     key_features = get_info(reference_folder)
-     if not key_features:
-         return [], f"No reference features found in {reference_folder}."
-     ref_keys = list(key_features.keys())
-     ref_array = np.array([key_features[k] for k in ref_keys])
-     key_feats_tensor = torch.tensor(ref_array, dtype=torch.float32)
-     query_tensor_expanded = query_tensor.expand(key_feats_tensor.size(0), -1)
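-     # Cosine similarity between the query and every reference feature (higher = semantically closer)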
-     similarities = torch.cosine_similarity(query_tensor_expanded, key_feats_tensor, dim=1)
-     ranked_indices = torch.argsort(similarities, descending=True)
-     candidate_ids = []
-     candidate_display = []
-     for i in range(top_k):
-         if i < len(ref_keys):
-             candidate_idx = ranked_indices[i].item()
-             candidate_id = ref_keys[candidate_idx]
-             sim = round(similarities[candidate_idx].item(), 4)
-             meta = metadata_map.get(candidate_id, {})
-             title = meta.get("title", candidate_id)
-             artists = meta.get("artists", "Unknown")
-             if isinstance(artists, list):
-                 artists = ", ".join(artists)
-             candidate_ids.append(candidate_id)
-             candidate_display.append(f"{title} | {artists} | sim: {sim}")
-         else:
-             candidate_ids.append("N/A")
-             candidate_display.append("N/A")
-     return candidate_ids, candidate_display
-
- def show_details(selected_id):
-     """
-     Return detailed metadata and embedded YouTube video HTML based on the candidate ID.
-     """
-     if selected_id == "N/A":
-         return ("", "", "", "", "", "", "", "")
-     data = metadata_map.get(selected_id, {})
-     if not data:
-         return ("No details found", "", "", "", "", "", "", "")
-     title = data.get("title", "")
-     artists = data.get("artists", "")
-     if isinstance(artists, list):
-         artists = ", ".join(artists)
-     genre = data.get("genre", "")
-     background = data.get("background", "")
-     analysis = data.get("analysis", "")
-     description = data.get("description", "")
-     scene = data.get("scene", "")
-     youtube_html = (
-         f'<iframe width="560" height="315" src="https://www.youtube.com/embed/{selected_id}" '
-         f'frameborder="0" allow="accelerometer; autoplay; clipboard-write; encrypted-media; '
-         f'gyroscope; picture-in-picture" allowfullscreen></iframe>'
-     )
-     return title, artists, genre, background, analysis, description, scene, youtube_html
-
- def extract_features_from_text(text):
-     """
-     Save the input text to a file, call the CLaMP 3 feature extraction script,
-     and return the generated feature file path.
-     """
-     input_dir = "input_dir"
-     output_dir = "output_dir"
-     os.makedirs(input_dir, exist_ok=True)
-     os.makedirs(output_dir, exist_ok=True)
-     # Clear input_dir and output_dir
-     for d in [input_dir, output_dir]:
-         for filename in os.listdir(d):
-             file_path = os.path.join(d, filename)
-             if os.path.isfile(file_path) or os.path.islink(file_path):
-                 os.unlink(file_path)
-             elif os.path.isdir(file_path):
-                 shutil.rmtree(file_path)
-     input_file = os.path.join(input_dir, "input.txt")
-     print("Text input:", text)
-     with open(input_file, "w", encoding="utf-8") as f:
-         f.write(text)
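-     # extract_clamp3.py encodes every file in input_dir; input.txt yields output_dir/input.npy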
-     command = ["python", "extract_clamp3.py", input_dir, output_dir, "--get_global"]
-     subprocess.run(command, check=True)
-     output_file = os.path.join(output_dir, "input.npy")
-     return output_file
-
- def generate_caption(image):
-     """
-     Use the BLIP model to generate a descriptive caption for the given image.
-     """
-     inputs = processor(image, return_tensors="pt")
-     outputs = blip_model.generate(**inputs)
-     caption = processor.decode(outputs[0], skip_special_tokens=True)
-     return caption
-
- class FileWrapper:
-     """
-     Simulate a file object with a .name attribute.
-     """
-     def __init__(self, path):
-         self.name = path
-
- def search_wrapper(search_mode, text_input, image_input):
-     """
-     Perform retrieval based on the selected input mode:
-     - If search_mode is "Image", generate a caption for the uploaded image with BLIP, then extract
-       features and search in the "image/" folder.
-     - If search_mode is "Text", extract features from the provided text and search in the "text/" folder.
-     """
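-     # image/ and text/ hold precomputed CLaMP 3 features for the WikiMT-X pieces;
-     # caption-based (image) queries search the former, plain-text queries the latter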
-     if search_mode == "Image":
-         if image_input is None:
-             return text_input, gr.update(choices=[]), "Please upload an image.", "", "", "", "", "", "", ""
-         caption = generate_caption(image_input)
-         text_to_use = caption
-         reference_folder = "image/"
-     elif search_mode == "Text":
-         if not text_input or text_input.strip() == "":
-             return "Describe the music you're looking for (in any language)", gr.update(choices=[]), "Please enter text for retrieval.", "", "", "", "", "", "", ""
-         text_to_use = text_input
-         reference_folder = "text/"
-     else:
-         return "Describe the music you're looking for (in any language)", gr.update(choices=[]), "Invalid search mode selected.", "", "", "", "", "", "", ""
-
-     try:
-         output_file = extract_features_from_text(text_to_use)
-         query_file = FileWrapper(output_file)
-     except Exception as e:
-         return text_to_use, gr.update(choices=[]), f"Error during feature extraction: {e}", "", "", "", "", "", "", ""
-     candidate_ids, candidate_display = find_top_similar(query_file, reference_folder)
-     if not candidate_ids:
-         return text_to_use, gr.update(choices=[]), "", "", "", "", "", "", "", ""
-     choices = [(f"{i+1}. {disp}", cid) for i, (cid, disp) in enumerate(zip(candidate_ids, candidate_display))]
-     top_candidate = candidate_ids[0]
-     details = show_details(top_candidate)
-     return text_to_use, gr.update(choices=choices), *details
-
- # Define the example data (defining it after the components also works)
- examples = [
-     ["Image", None, "V4EauuhVEw4.jpg"],
-     ["Image", None, "Kw-_Ew5bVxs.jpg"],
-     ["Image", None, "BuYf0taXoNw.webp"],
-     ["Image", None, "4tDYMayp6Dk.jpg"],
-     ["Text", "classic rock, British, 1960s, upbeat", None],
-     ["Text", "A Latin jazz piece with rhythmic percussion and brass", None],
-     ["Text", "big band, major key, swing, brass-heavy, syncopation, baritone vocal", None],
-     ["Text", "Heartfelt and nostalgic, with a bittersweet, melancholic feel", None],
-     ["Text", "Melodía instrumental en re mayor con progresión armónica repetitiva y fluida", None],
-     ["Text", "D大调四四拍的爱尔兰舞曲", None],
-     ["Text", "Ιερή μουσική με πνευματική ατμόσφαιρα", None],
-     ["Text", "የፍቅር ሙዚቃ ሞቅ እና ስሜታማ ከሆነ ነገር ግን ድንቅ እና አስደሳች ቃላት ያካትታል", None],
- ]
-
- with gr.Blocks() as demo:
-     gr.Markdown(title)
-     gr.HTML(badges)
-     gr.Markdown(description)
-
-     with gr.Row():
-         with gr.Column():
-             search_mode = gr.Radio(
-                 choices=["Text", "Image"],
-                 label="Select Search Mode",
-                 value="Text",
-                 interactive=True,
-                 elem_classes=["vertical-radio"]
-             )
-             text_input = gr.Textbox(
-                 placeholder="Describe the music you're looking for (in any language)",
-                 lines=4
-             )
-             image_input = gr.Image(
-                 label="Or upload an image (PNG, JPG)",
-                 type="pil"
-             )
-             search_button = gr.Button("Search")
-             candidate_radio = gr.Radio(choices=[], label="Select Retrieval Result", interactive=True, elem_classes=["vertical-radio"])
-         with gr.Column():
-             gr.Markdown("### YouTube Video")
-             youtube_box = gr.HTML(label="YouTube Video")
-             gr.Markdown("### Metadata")
-             title_box = gr.Textbox(label="Title", interactive=False)
-             artists_box = gr.Textbox(label="Artists", interactive=False)
-             genre_box = gr.Textbox(label="Genre", interactive=False)
-             background_box = gr.Textbox(label="Background", interactive=False)
-             analysis_box = gr.Textbox(label="Analysis", interactive=False)
-             description_box = gr.Textbox(label="Description", interactive=False)
-             scene_box = gr.Textbox(label="Scene", interactive=False)
-
-     gr.HTML(
-         """
-         <style>
-         .vertical-radio .gradio-radio label {
-             display: block !important;
-             margin-bottom: 5px;
-         }
-         </style>
-         """
-     )
-
-     gr.Examples(
-         examples=examples,
-         inputs=[search_mode, text_input, image_input],
-         outputs=[text_input, candidate_radio, title_box, artists_box, genre_box, background_box, analysis_box, description_box, scene_box, youtube_box],
-         fn=search_wrapper,
-         cache_examples=False,
-     )
-
-     search_button.click(
-         fn=search_wrapper,
-         inputs=[search_mode, text_input, image_input],
-         outputs=[text_input, candidate_radio, title_box, artists_box, genre_box, background_box, analysis_box, description_box, scene_box, youtube_box]
-     )
-
-     candidate_radio.change(
-         fn=show_details,
-         inputs=candidate_radio,
-         outputs=[title_box, artists_box, genre_box, background_box, analysis_box, description_box, scene_box, youtube_box]
-     )
-
- demo.launch()
 
+ import os
+ import torch
+ import numpy as np
+ import gradio as gr
+ import zipfile
+ import json
+ import requests
+ import subprocess
+ import shutil
+ from transformers import BlipProcessor, BlipForConditionalGeneration
+
+ title = "# 🗜️ CLaMP 3 - Multimodal & Multilingual Semantic Music Search"
+
+ badges = """
+ <div style="text-align: center;">
+ <a href="https://sanderwood.github.io/clamp3/">
+ <img src="https://img.shields.io/badge/CLaMP%203%20Homepage-GitHub-181717?style=for-the-badge&logo=home-assistant" alt="Homepage">
+ </a>
+ <a href="https://arxiv.org/abs/2502.10362">
+ <img src="https://img.shields.io/badge/CLaMP%203%20Paper-Arxiv-red?style=for-the-badge&logo=arxiv" alt="Paper">
+ </a>
+ <a href="https://github.com/sanderwood/clamp3">
+ <img src="https://img.shields.io/badge/CLaMP%203%20Code-GitHub-181717?style=for-the-badge&logo=github" alt="GitHub">
+ </a>
+ <a href="https://huggingface.co/spaces/sander-wood/clamp3">
+ <img src="https://img.shields.io/badge/CLaMP%203%20Demo-Gradio-green?style=for-the-badge&logo=gradio" alt="Demo">
+ </a>
+ <a href="https://huggingface.co/sander-wood/clamp3/tree/main">
+ <img src="https://img.shields.io/badge/Model%20Weights-Hugging%20Face-ffcc00?style=for-the-badge&logo=huggingface" alt="Model Weights">
+ </a>
+ <a href="https://huggingface.co/datasets/sander-wood/m4-rag">
+ <img src="https://img.shields.io/badge/M4--RAG%20Dataset-Hugging%20Face-ffcc00?style=for-the-badge&logo=huggingface" alt="Dataset">
+ </a>
+ <a href="https://huggingface.co/datasets/sander-wood/wikimt-x">
+ <img src="https://img.shields.io/badge/WikiMT--X%20Benchmark-Hugging%20Face-ffcc00?style=for-the-badge&logo=huggingface" alt="Benchmark">
+ </a>
+ </div>
+
+ <style>
+ div a {
+     display: inline-block;
+     margin: 5px;
+ }
+ div a img {
+     height: 30px;
+ }
+ </style>
+ """
+
+ description = """CLaMP 3 is a **multimodal and multilingual** music information retrieval (MIR) framework, supporting **sheet music, audio, and performance signals** in **100 languages**. Using **contrastive learning**, it aligns these modalities in a shared space for **cross-modal retrieval**.
+
+ ### 🔍 **How This Demo Works**
+ - You can **retrieve music using any text input (in any language) or an image** (`.png`, `.jpg`).
+ - When using an image, **BLIP** generates a caption, which is then used for retrieval.
+ - Since CLaMP 3's training data includes **rich visual descriptions of musical scenes**, it can **match images to semantically relevant music**.
+
+ ### ⚠️ **Limitations**
+ - This demo retrieves music **only from the WikiMT-X benchmark (1,000 pieces)**.
+ - These pieces are **mainly from the U.S. and Western Europe (especially the U.S.)** and **mostly from the 20th century**.
+ - Thus, retrieval results are **mostly limited to Western 20th-century music**, so you **won’t** find music from **other regions or historical periods**.
+
+ 🔧 **Need retrieval for a different music collection?** Deploy **[CLaMP 3](https://github.com/sanderwood/clamp3)** on your own dataset.
+ Generally, the larger and more diverse the reference music dataset, the better the retrieval quality, increasing the likelihood of finding relevant and accurately matched music.
+
+ **Note: This project is for research use only.**
+ """
+
+ # Load BLIP image captioning model and processor
+ processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
+ blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
+
+ # Download weight file if it does not exist
+ weights_url = "https://huggingface.co/sander-wood/clamp3/resolve/main/weights_clamp3_saas_h_size_768_t_model_FacebookAI_xlm-roberta-base_t_length_128_a_size_768_a_layers_12_a_length_128_s_size_768_s_layers_12_p_size_64_p_length_512.pth"
+ weights_filename = "weights_clamp3_saas_h_size_768_t_model_FacebookAI_xlm-roberta-base_t_length_128_a_size_768_a_layers_12_a_length_128_s_size_768_s_layers_12_p_size_64_p_length_512.pth"
+
+ if not os.path.exists(weights_filename):
+     print("Downloading weights file...")
+     response = requests.get(weights_url, stream=True)
+     response.raise_for_status()
+     with open(weights_filename, "wb") as f:
+         for chunk in response.iter_content(chunk_size=8192):
+             if chunk:
+                 f.write(chunk)
+     print("Weights file downloaded.")
+
+ ZIP_PATH = "features.zip"
+ if os.path.exists(ZIP_PATH):
+     print(f"Extracting {ZIP_PATH}...")
+     with zipfile.ZipFile(ZIP_PATH, "r") as zip_ref:
+         zip_ref.extractall(".")
+     print("Extraction complete.")
+
+ # Load metadata
+ metadata_map = {}
+ METADATA_FILE = "wikimt-x-public.jsonl"
+ if os.path.exists(METADATA_FILE):
+     with open(METADATA_FILE, "r", encoding="utf-8") as f:
+         for line in f:
+             data = json.loads(line)
+             metadata_map[data["id"]] = data
+ else:
+     print(f"Warning: {METADATA_FILE} not found.")
+
+ features_cache = {}
+
+ def get_info(folder_path):
+     """
+     Load all .npy files from the specified folder and return a dictionary
+     with the file names (without extension) as keys.
+     """
+     if folder_path in features_cache:
+         return features_cache[folder_path]
+     if not os.path.exists(folder_path):
+         return {}
+     files = sorted(os.listdir(folder_path))
+     features = {}
+     for file in files:
+         if file.endswith(".npy"):
+             key = file.split(".")[0]
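+             # Each stored .npy has a leading batch axis; [0] keeps the single global feature vector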
+             try:
+                 features[key] = np.load(os.path.join(folder_path, file))[0]
+             except Exception as e:
+                 print(f"Error loading {file}: {e}")
+     features_cache[folder_path] = features
+     return features
+
+ def find_top_similar(query_file, reference_folder):
+     """
+     Compare the query feature with all reference features in the specified folder
+     using cosine similarity and return the top 10 candidate results in the format:
+     Title | Artists | sim: SimilarityScore.
+     """
+     top_k = 10
+     try:
+         query_feature = np.load(query_file.name)[0]
+     except Exception as e:
+         return [], f"Error loading query feature: {e}"
+     query_tensor = torch.tensor(query_feature, dtype=torch.float32).unsqueeze(dim=0)
+     key_features = get_info(reference_folder)
+     if not key_features:
+         return [], f"No reference features found in {reference_folder}."
+     ref_keys = list(key_features.keys())
+     ref_array = np.array([key_features[k] for k in ref_keys])
+     key_feats_tensor = torch.tensor(ref_array, dtype=torch.float32)
+     query_tensor_expanded = query_tensor.expand(key_feats_tensor.size(0), -1)
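+     # Cosine similarity between the query and every reference feature (higher = semantically closer)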
+     similarities = torch.cosine_similarity(query_tensor_expanded, key_feats_tensor, dim=1)
+     ranked_indices = torch.argsort(similarities, descending=True)
+     candidate_ids = []
+     candidate_display = []
+     for i in range(top_k):
+         if i < len(ref_keys):
+             candidate_idx = ranked_indices[i].item()
+             candidate_id = ref_keys[candidate_idx]
+             sim = round(similarities[candidate_idx].item(), 4)
+             meta = metadata_map.get(candidate_id, {})
+             title = meta.get("title", candidate_id)
+             artists = meta.get("artists", "Unknown")
+             if isinstance(artists, list):
+                 artists = ", ".join(artists)
+             candidate_ids.append(candidate_id)
+             candidate_display.append(f"{title} | {artists} | sim: {sim}")
+         else:
+             candidate_ids.append("N/A")
+             candidate_display.append("N/A")
+     return candidate_ids, candidate_display
+
+ def show_details(selected_id):
+     """
+     Return detailed metadata and embedded YouTube video HTML based on the candidate ID.
+     """
+     if selected_id == "N/A":
+         return ("", "", "", "", "", "", "", "")
+     data = metadata_map.get(selected_id, {})
+     if not data:
+         return ("No details found", "", "", "", "", "", "", "")
+     title = data.get("title", "")
+     artists = data.get("artists", "")
+     if isinstance(artists, list):
+         artists = ", ".join(artists)
+     genre = data.get("genre", "")
+     background = data.get("background", "")
+     analysis = data.get("analysis", "")
+     description = data.get("description", "")
+     scene = data.get("scene", "")
+     youtube_html = (
+         f'<iframe width="560" height="315" src="https://www.youtube.com/embed/{selected_id}" '
+         f'frameborder="0" allow="accelerometer; autoplay; clipboard-write; encrypted-media; '
+         f'gyroscope; picture-in-picture" allowfullscreen></iframe>'
+     )
+     return title, artists, genre, background, analysis, description, scene, youtube_html
+
+ def extract_features_from_text(text):
+     """
+     Save the input text to a file, call the CLaMP 3 feature extraction script,
+     and return the generated feature file path.
+     """
+     input_dir = "input_dir"
+     output_dir = "output_dir"
+     os.makedirs(input_dir, exist_ok=True)
+     os.makedirs(output_dir, exist_ok=True)
+     # Clear input_dir and output_dir
+     for d in [input_dir, output_dir]:
+         for filename in os.listdir(d):
+             file_path = os.path.join(d, filename)
+             if os.path.isfile(file_path) or os.path.islink(file_path):
+                 os.unlink(file_path)
+             elif os.path.isdir(file_path):
+                 shutil.rmtree(file_path)
+     input_file = os.path.join(input_dir, "input.txt")
+     print("Text input:", text)
+     with open(input_file, "w", encoding="utf-8") as f:
+         f.write(text)
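+     # extract_clamp3.py encodes every file in input_dir; input.txt yields output_dir/input.npy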
+     command = ["python", "extract_clamp3.py", input_dir, output_dir, "--get_global"]
+     subprocess.run(command, check=True)
+     output_file = os.path.join(output_dir, "input.npy")
+     return output_file
+
+ def generate_caption(image):
+     """
+     Use the BLIP model to generate a descriptive caption for the given image.
+     """
+     inputs = processor(image, return_tensors="pt")
+     outputs = blip_model.generate(**inputs)
+     caption = processor.decode(outputs[0], skip_special_tokens=True)
+     return caption
+
+ class FileWrapper:
+     """
+     Simulate a file object with a .name attribute.
+     """
+     def __init__(self, path):
+         self.name = path
+
+ def search_wrapper(search_mode, text_input, image_input):
+     """
+     Perform retrieval based on the selected input mode:
+     - If search_mode is "Image", generate a caption for the uploaded image with BLIP, then extract
+       features and search in the "image/" folder.
+     - If search_mode is "Text", extract features from the provided text and search in the "text/" folder.
+     """
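+     # image/ and text/ hold precomputed CLaMP 3 features for the WikiMT-X pieces;
+     # caption-based (image) queries search the former, plain-text queries the latter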
+     if search_mode == "Image":
+         if image_input is None:
+             return text_input, gr.update(choices=[]), "Please upload an image.", "", "", "", "", "", "", ""
+         caption = generate_caption(image_input)
+         text_to_use = caption
+         reference_folder = "image/"
+     elif search_mode == "Text":
+         if not text_input or text_input.strip() == "":
+             return "Describe the music you're looking for (in any language)", gr.update(choices=[]), "Please enter text for retrieval.", "", "", "", "", "", "", ""
+         text_to_use = text_input
+         reference_folder = "text/"
+     else:
+         return "Describe the music you're looking for (in any language)", gr.update(choices=[]), "Invalid search mode selected.", "", "", "", "", "", "", ""
+
+     try:
+         output_file = extract_features_from_text(text_to_use)
+         query_file = FileWrapper(output_file)
+     except Exception as e:
+         return text_to_use, gr.update(choices=[]), f"Error during feature extraction: {e}", "", "", "", "", "", "", ""
+     candidate_ids, candidate_display = find_top_similar(query_file, reference_folder)
+     if not candidate_ids:
+         return text_to_use, gr.update(choices=[]), "", "", "", "", "", "", "", ""
+     choices = [(f"{i+1}. {disp}", cid) for i, (cid, disp) in enumerate(zip(candidate_ids, candidate_display))]
+     top_candidate = candidate_ids[0]
+     details = show_details(top_candidate)
+     return text_to_use, gr.update(choices=choices), *details
+
+ # Define the example data (defining it after the components also works)
+ examples = [
+     ["Image", None, "V4EauuhVEw4.jpg"],
+     ["Image", None, "Kw-_Ew5bVxs.jpg"],
+     ["Image", None, "BuYf0taXoNw.webp"],
+     ["Image", None, "4tDYMayp6Dk.jpg"],
+     ["Text", "classic rock, British, 1960s, upbeat", None],
+     ["Text", "A Latin jazz piece with rhythmic percussion and brass", None],
+     ["Text", "big band, major key, swing, brass-heavy, syncopation, baritone vocal", None],
+     ["Text", "Heartfelt and nostalgic, with a bittersweet, melancholic feel", None],
+     ["Text", "Melodía instrumental en re mayor con progresión armónica repetitiva y fluida", None],
+     ["Text", "D大调四四拍的爱尔兰舞曲", None],
+     ["Text", "Ιερή μουσική με πνευματική ατμόσφαιρα", None],
+     ["Text", "የፍቅር ሙዚቃ ሞቅ እና ስሜታማ ከሆነ ነገር ግን ድንቅ እና አስደሳች ቃላት ያካትታል", None],
+ ]
+
+ with gr.Blocks() as demo:
+     gr.Markdown(title)
+     gr.HTML(badges)
+     gr.Markdown(description)
+
+     with gr.Row():
+         with gr.Column():
+             search_mode = gr.Radio(
+                 choices=["Text", "Image"],
+                 label="Select Search Mode",
+                 value="Text",
+                 interactive=True,
+                 elem_classes=["vertical-radio"]
+             )
+             text_input = gr.Textbox(
+                 placeholder="Describe the music you're looking for (in any language)",
+                 lines=4
+             )
+             image_input = gr.Image(
+                 label="Or upload an image (PNG, JPG)",
+                 type="pil"
+             )
+             search_button = gr.Button("Search 1,000 Western 20th-century pieces in WikiMT-X")
+             candidate_radio = gr.Radio(choices=[], label="Select Retrieval Result", interactive=True, elem_classes=["vertical-radio"])
+         with gr.Column():
+             gr.Markdown("### YouTube Video")
+             youtube_box = gr.HTML(label="YouTube Video")
+             gr.Markdown("### Metadata")
+             title_box = gr.Textbox(label="Title", interactive=False)
+             artists_box = gr.Textbox(label="Artists", interactive=False)
+             genre_box = gr.Textbox(label="Genre", interactive=False)
+             background_box = gr.Textbox(label="Background", interactive=False)
+             analysis_box = gr.Textbox(label="Analysis", interactive=False)
+             description_box = gr.Textbox(label="Description", interactive=False)
+             scene_box = gr.Textbox(label="Scene", interactive=False)
+
+     gr.HTML(
+         """
+         <style>
+         .vertical-radio .gradio-radio label {
+             display: block !important;
+             margin-bottom: 5px;
+         }
+         </style>
+         """
+     )
+
+     gr.Examples(
+         examples=examples,
+         inputs=[search_mode, text_input, image_input],
+         outputs=[text_input, candidate_radio, title_box, artists_box, genre_box, background_box, analysis_box, description_box, scene_box, youtube_box],
+         fn=search_wrapper,
+         cache_examples=False,
+     )
+
+     search_button.click(
+         fn=search_wrapper,
+         inputs=[search_mode, text_input, image_input],
+         outputs=[text_input, candidate_radio, title_box, artists_box, genre_box, background_box, analysis_box, description_box, scene_box, youtube_box]
+     )
+
+     candidate_radio.change(
+         fn=show_details,
+         inputs=candidate_radio,
+         outputs=[title_box, artists_box, genre_box, background_box, analysis_box, description_box, scene_box, youtube_box]
+     )
+
+ demo.launch()