Spaces:
Runtime error
Runtime error
Commit
·
043d857
1
Parent(s):
2570d24
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from sentence_transformers import SentenceTransformer, util
|
2 |
+
from huggingface_hub import hf_hub_download
|
3 |
+
import pickle
|
4 |
+
import pandas as pd
|
5 |
+
import gradio as gr
|
6 |
+
|
7 |
+
pd.options.mode.chained_assignment = None # Turn off SettingWithCopyWarning
|
8 |
+
|
9 |
+
pickled = pickle.load(open(hf_hub_download("NimaBoscarino/playlist-generator", filename="clean-large_embeddings_msmarco-MiniLM-L-6-v3.pkl"), "rb"))
|
10 |
+
songs = pd.read_csv(hf_hub_download("NimaBoscarino/playlist-generator", filename="songs_new.csv"))
|
11 |
+
verses = pickle.load(open(hf_hub_download("NimaBoscarino/playlist-generator", filename="verses.pkl"), "rb"))
|
12 |
+
lyrics = pd.read_csv(hf_hub_download("NimaBoscarino/playlist-generator", filename="lyrics_new.csv"))
|
13 |
+
|
14 |
+
embedder = SentenceTransformer('msmarco-MiniLM-L-6-v3')
|
15 |
+
|
16 |
+
genius_ids = pickled["genius_ids"]
|
17 |
+
corpus_embeddings = pickled["embeddings"]
|
18 |
+
|
19 |
+
|
20 |
+
def generate_playlist(prompt):
|
21 |
+
prompt_embedding = embedder.encode(prompt, convert_to_tensor=True)
|
22 |
+
hits = util.semantic_search(prompt_embedding, corpus_embeddings, top_k=20)
|
23 |
+
hits = pd.DataFrame(hits[0], columns=['corpus_id', 'score'])
|
24 |
+
|
25 |
+
verse_match = verses.iloc[hits['corpus_id']]
|
26 |
+
verse_match = verse_match.drop_duplicates(subset=["genius_id"])
|
27 |
+
song_match = songs[songs["genius_id"].isin(verse_match["genius_id"].values)]
|
28 |
+
song_match.genius_id = pd.Categorical(song_match.genius_id, categories=verse_match["genius_id"].values)
|
29 |
+
song_match = song_match.sort_values("genius_id")
|
30 |
+
song_match = song_match[0:9] # Only grab the top 9
|
31 |
+
|
32 |
+
song_names = list(song_match["full_title"])
|
33 |
+
song_art = list(song_match["art"].fillna("https://i.imgur.com/bgCDfT1.jpg"))
|
34 |
+
images = [gr.Image.update(value=art, visible=True) for art in song_art]
|
35 |
+
|
36 |
+
return (
|
37 |
+
gr.Radio.update(label="Songs", interactive=True, choices=song_names),
|
38 |
+
*images
|
39 |
+
)
|
40 |
+
|
41 |
+
|
42 |
+
def set_lyrics(full_title):
|
43 |
+
lyrics_text = lyrics[lyrics["genius_id"].isin(songs[songs["full_title"] == full_title]["genius_id"])]["text"].iloc[0]
|
44 |
+
return gr.Textbox.update(value=lyrics_text)
|
45 |
+
|
46 |
+
|
47 |
+
def set_example_prompt(example):
|
48 |
+
return gr.TextArea.update(value=example[0])
|
49 |
+
|
50 |
+
|
51 |
+
demo = gr.Blocks()
|
52 |
+
|
53 |
+
with demo:
|
54 |
+
gr.Markdown(
|
55 |
+
"""
|
56 |
+
# Playlist Generator 📻 🎵
|
57 |
+
""")
|
58 |
+
|
59 |
+
with gr.Row():
|
60 |
+
with gr.Column():
|
61 |
+
gr.Markdown(
|
62 |
+
"""
|
63 |
+
Enter a prompt and generate a playlist based on ✨semantic similarity✨
|
64 |
+
This was built using Sentence Transformers and Gradio – [read more here!](#)
|
65 |
+
""")
|
66 |
+
|
67 |
+
song_prompt = gr.TextArea(
|
68 |
+
value="Running wild and free",
|
69 |
+
placeholder="Enter a song prompt, or choose an example"
|
70 |
+
)
|
71 |
+
example_prompts = gr.Dataset(
|
72 |
+
components=[song_prompt],
|
73 |
+
samples=[
|
74 |
+
["I feel nostalgic for the past"],
|
75 |
+
["Running wild and free"],
|
76 |
+
["I'm deeply in love with someone I just met!"],
|
77 |
+
["My friends mean the world to me"],
|
78 |
+
["Sometimes I feel like no one understands"],
|
79 |
+
]
|
80 |
+
)
|
81 |
+
|
82 |
+
with gr.Column():
|
83 |
+
fetch_songs = gr.Button(value="Generate Your Playlist 🧑🏽🎤").style(full_width=True)
|
84 |
+
|
85 |
+
with gr.Row():
|
86 |
+
tile1 = gr.Image(value="https://i.imgur.com/bgCDfT1.jpg", show_label=False, visible=True)
|
87 |
+
tile2 = gr.Image(value="https://i.imgur.com/bgCDfT1.jpg", show_label=False, visible=True)
|
88 |
+
tile3 = gr.Image(value="https://i.imgur.com/bgCDfT1.jpg", show_label=False, visible=True)
|
89 |
+
with gr.Row():
|
90 |
+
tile4 = gr.Image(value="https://i.imgur.com/bgCDfT1.jpg", show_label=False, visible=True)
|
91 |
+
tile5 = gr.Image(value="https://i.imgur.com/bgCDfT1.jpg", show_label=False, visible=True)
|
92 |
+
tile6 = gr.Image(value="https://i.imgur.com/bgCDfT1.jpg", show_label=False, visible=True)
|
93 |
+
with gr.Row():
|
94 |
+
tile7 = gr.Image(value="https://i.imgur.com/bgCDfT1.jpg", show_label=False, visible=True)
|
95 |
+
tile8 = gr.Image(value="https://i.imgur.com/bgCDfT1.jpg", show_label=False, visible=True)
|
96 |
+
tile9 = gr.Image(value="https://i.imgur.com/bgCDfT1.jpg", show_label=False, visible=True)
|
97 |
+
|
98 |
+
# Workaround because of the Gallery issues
|
99 |
+
tiles = [tile1, tile2, tile3, tile4, tile5, tile6, tile7, tile8, tile9]
|
100 |
+
|
101 |
+
song_option = gr.Radio(label="Songs", interactive=True, choices=None, type="value")
|
102 |
+
|
103 |
+
with gr.Column():
|
104 |
+
verse = gr.Textbox(label="Verse", placeholder="Select a song to see its lyrics")
|
105 |
+
|
106 |
+
fetch_songs.click(
|
107 |
+
fn=generate_playlist,
|
108 |
+
inputs=[song_prompt],
|
109 |
+
outputs=[song_option, *tiles],
|
110 |
+
)
|
111 |
+
|
112 |
+
example_prompts.click(
|
113 |
+
fn=set_example_prompt,
|
114 |
+
inputs=example_prompts,
|
115 |
+
outputs=example_prompts.components,
|
116 |
+
)
|
117 |
+
|
118 |
+
song_option.change(
|
119 |
+
fn=set_lyrics,
|
120 |
+
inputs=[song_option],
|
121 |
+
outputs=[verse]
|
122 |
+
)
|
123 |
+
|
124 |
+
demo.launch()
|