Bils committed on
Commit
d0384c8
·
verified Β·
1 Parent(s): 9ff3e1b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +110 -129
app.py CHANGED
@@ -10,13 +10,13 @@ from transformers import (
10
  MusicgenForConditionalGeneration
11
  )
12
  from io import BytesIO
13
- from streamlit_lottie import st_lottie # pip install streamlit-lottie
14
 
15
  # ---------------------------------------------------------------------
16
  # 1) PAGE CONFIG
17
  # ---------------------------------------------------------------------
18
  st.set_page_config(
19
- page_title="Radio Imaging AI MVP",
20
  page_icon="🎧",
21
  layout="wide"
22
  )
@@ -26,26 +26,24 @@ st.set_page_config(
26
  # ---------------------------------------------------------------------
27
  CUSTOM_CSS = """
28
  <style>
29
- /* Body styling for a dark, music-app vibe */
30
  body {
31
  background-color: #121212;
32
  color: #FFFFFF;
33
  font-family: "Helvetica Neue", sans-serif;
34
  }
35
 
36
- /* Main container width */
37
  .block-container {
38
  max-width: 1100px;
39
  padding: 1rem 1.5rem;
40
  }
41
 
42
- /* Headings with a neon-ish green accent */
43
  h1, h2, h3 {
44
- color: #1DB954;
45
  margin-bottom: 0.5rem;
46
  }
47
 
48
- /* Buttons: rounded, bright Spotify-like green on hover */
49
  .stButton>button {
50
  background-color: #1DB954 !important;
51
  color: #FFFFFF !important;
@@ -59,13 +57,12 @@ h1, h2, h3 {
59
  background-color: #1ed760 !important;
60
  }
61
 
62
- /* Sidebar: black background, white text */
63
  .sidebar .sidebar-content {
64
  background-color: #000000;
65
  color: #FFFFFF;
66
  }
67
 
68
- /* Text inputs and text areas */
69
  textarea, input, select {
70
  border-radius: 8px !important;
71
  background-color: #282828 !important;
@@ -73,20 +70,20 @@ textarea, input, select {
73
  border: 1px solid #3e3e3e;
74
  }
75
 
76
- /* Audio player styling */
77
  audio {
78
  width: 100%;
79
  margin-top: 1rem;
80
  }
81
 
82
- /* Lottie container styling */
83
  .lottie-container {
84
  display: flex;
85
  justify-content: center;
86
  margin-bottom: 20px;
87
  }
88
 
89
- /* Footer styling */
90
  .footer-note {
91
  text-align: center;
92
  font-size: 14px;
@@ -94,31 +91,83 @@ audio {
94
  margin-top: 2rem;
95
  }
96
 
97
- /* Hide Streamlit's default branding if desired */
98
  #MainMenu, footer {visibility: hidden;}
99
  </style>
100
  """
101
  st.markdown(CUSTOM_CSS, unsafe_allow_html=True)
102
 
103
  # ---------------------------------------------------------------------
104
- # 3) HELPER: LOAD LOTTIE ANIMATION
105
  # ---------------------------------------------------------------------
106
  @st.cache_data
107
  def load_lottie_url(url: str):
108
- """
109
- Fetch Lottie JSON for animations.
110
- """
111
  r = requests.get(url)
112
  if r.status_code != 200:
113
  return None
114
  return r.json()
115
 
116
- # Example Lottie animation (radio waves / music eq, etc.)
117
  LOTTIE_URL = "https://assets3.lottiefiles.com/temp/lf20_Q6h5zV.json"
118
  lottie_animation = load_lottie_url(LOTTIE_URL)
119
 
120
  # ---------------------------------------------------------------------
121
- # 4) SIDEBAR: "LIBRARY" NAVIGATION (MIMICS SPOTIFY)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
  # ---------------------------------------------------------------------
123
  with st.sidebar:
124
  st.header("🎚 Radio Library")
@@ -131,21 +180,19 @@ with st.sidebar:
131
  st.markdown("<br>", unsafe_allow_html=True)
132
 
133
  # ---------------------------------------------------------------------
134
- # 5) HEADER SECTION WITH LOTS OF FLARE
135
  # ---------------------------------------------------------------------
136
  col1, col2 = st.columns([3, 2], gap="large")
137
 
138
  with col1:
139
- st.title("AI Radio Imaging MVP")
140
- st.subheader("Llama-Driven Promo Scripts, MusicGen Audio")
141
 
142
  st.markdown(
143
  """
144
- Create **radio imaging promos** and **jingles** with a minimal but creative MVP.
145
- This app:
146
- - Uses a (hypothetical) [Llama 3] model for **script generation**.
147
- - Uses Meta's [MusicGen](https://github.com/facebookresearch/audiocraft) for **audio**.
148
- - Features a Spotify-like UI & Lottie animations for a modern user experience.
149
  """
150
  )
151
  with col2:
@@ -158,152 +205,86 @@ with col2:
158
  st.markdown("---")
159
 
160
  # ---------------------------------------------------------------------
161
- # 6) PROMPT INPUT & MODEL SELECTION
162
  # ---------------------------------------------------------------------
163
- st.subheader("🎙 Step 1: Briefly Describe Your Promo Idea")
164
 
165
  prompt = st.text_area(
166
- "E.g. 'A 15-second upbeat jingle with a catchy hook for a Top 40 morning show'",
167
  height=120
168
  )
169
 
170
  col_model, col_device = st.columns(2)
171
  with col_model:
172
  llama_model_id = st.text_input(
173
- "Llama Model (Hugging Face ID)",
174
- value="meta-llama/Llama-3.3-70B-Instruct", # Replace with a real model
175
- help="If non-existent, you'll see errors. Try Llama 2 (e.g. meta-llama/Llama-2-7b-chat-hf)."
176
  )
177
  with col_device:
178
  device_option = st.selectbox(
179
- "Choose Device",
180
  ["auto", "cpu"],
181
- help="For GPU usage, pick 'auto'. CPU can be slow for big models."
182
  )
183
 
184
- # ---------------------------------------------------------------------
185
- # 7) BUTTON: GENERATE RADIO SCRIPT WITH LLAMA
186
- # ---------------------------------------------------------------------
187
  if st.button("📝 Generate Promo Script"):
188
  if not prompt.strip():
189
- st.error("Please enter a radio imaging concept first.")
190
  else:
191
- with st.spinner("Generating script..."):
192
  try:
193
- # Load Llama pipeline
194
- pipeline_llama = load_llama_pipeline(llama_model_id, device_option)
195
- # Generate refined script
196
- refined_text = generate_radio_script(prompt, pipeline_llama)
197
- st.session_state["refined_script"] = refined_text
198
  st.success("Promo script generated!")
199
- st.write(refined_text)
200
  except Exception as e:
201
- st.error(f"Error during Llama generation: {e}")
202
 
203
  st.markdown("---")
204
 
205
  # ---------------------------------------------------------------------
206
- # 8) AUDIO GENERATION: MUSICGEN
207
  # ---------------------------------------------------------------------
208
- st.subheader("🎢 Step 2: Generate Your Radio Audio")
209
 
210
- audio_tokens = st.slider("MusicGen Max Tokens (Track Length)", 128, 1024, 512, 64)
211
 
212
  if st.button("🎧 Create Audio with MusicGen"):
213
- # Check if we have a refined script
214
- if "refined_script" not in st.session_state:
215
- st.error("Please generate a promo script first.")
216
  else:
217
- with st.spinner("Generating audio..."):
218
  try:
219
- # Load MusicGen
220
  mg_model, mg_processor = load_musicgen_model()
221
- descriptive_text = st.session_state["refined_script"]
222
-
223
- # Prepare model input
224
  inputs = mg_processor(
225
- text=[descriptive_text],
226
- return_tensors="pt",
227
- padding=True
228
  )
229
- # Generate audio
230
- audio_values = mg_model.generate(**inputs, max_new_tokens=audio_tokens)
231
  sr = mg_model.config.audio_encoder.sampling_rate
232
-
233
- # Save audio to WAV
234
- out_filename = "radio_imaging_output.wav"
235
- scipy.io.wavfile.write(out_filename, rate=sr, data=audio_values[0,0].numpy())
236
-
237
- st.success("Audio created! Press play to listen:")
238
- st.audio(out_filename)
239
- except Exception as e:
240
- st.error(f"Error generating audio: {e}")
241
-
242
- # ---------------------------------------------------------------------
243
- # 9) HELPER FUNCTIONS
244
- # ---------------------------------------------------------------------
245
- @st.cache_resource
246
- def load_llama_pipeline(model_id: str, device: str):
247
- """
248
- Load the Llama model & pipeline.
249
- """
250
- tokenizer = AutoTokenizer.from_pretrained(model_id)
251
- model = AutoModelForCausalLM.from_pretrained(
252
- model_id,
253
- torch_dtype=torch.float16 if device == "auto" else torch.float32,
254
- device_map=device
255
- )
256
- text_gen_pipeline = pipeline(
257
- "text-generation",
258
- model=model,
259
- tokenizer=tokenizer,
260
- device_map=device
261
- )
262
- return text_gen_pipeline
263
 
264
- def generate_radio_script(user_input: str, pipeline_llama) -> str:
265
- """
266
- Use Llama to refine the user's input into a brief but creative radio imaging script.
267
- """
268
- system_prompt = (
269
- "You are a top-tier radio imaging producer. "
270
- "Take the user's concept and craft a short, high-impact promo script. "
271
- "Include style, tone, and potential CTA if relevant."
272
- )
273
- full_prompt = f"{system_prompt}\nUser concept: {user_input}\nRefined script:"
274
-
275
- output = pipeline_llama(
276
- full_prompt,
277
- max_new_tokens=200,
278
- do_sample=True,
279
- temperature=0.9
280
- )[0]["generated_text"]
281
-
282
- # Attempt to isolate the final script portion
283
- if "Refined script:" in output:
284
- output = output.split("Refined script:", 1)[-1].strip()
285
- output += "\n\n(Generated by Llama in Radio Imaging MVP)"
286
-
287
- return output
288
 
289
- @st.cache_resource
290
- def load_musicgen_model():
291
- """
292
- Load MusicGen (small version).
293
- """
294
- mg_model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
295
- mg_processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
296
- return mg_model, mg_processor
297
 
298
  # ---------------------------------------------------------------------
299
- # 10) FOOTER
300
  # ---------------------------------------------------------------------
301
  st.markdown("---")
302
  st.markdown(
303
  """
304
  <div class="footer-note">
305
- &copy; 2025 Radio Imaging MVP &ndash; Built with Llama & MusicGen. <br>
306
- Inspired by Spotify's UI for a sleek, modern experience.
307
  </div>
308
  """,
309
  unsafe_allow_html=True
 
10
  MusicgenForConditionalGeneration
11
  )
12
  from io import BytesIO
13
+ from streamlit_lottie import st_lottie
14
 
15
  # ---------------------------------------------------------------------
16
  # 1) PAGE CONFIG
17
  # ---------------------------------------------------------------------
18
  st.set_page_config(
19
+ page_title="Radio Imaging AI with Llama 3",
20
  page_icon="🎧",
21
  layout="wide"
22
  )
 
26
  # ---------------------------------------------------------------------
27
  CUSTOM_CSS = """
28
  <style>
29
+ /* Dark background with Spotify-like vibe */
30
  body {
31
  background-color: #121212;
32
  color: #FFFFFF;
33
  font-family: "Helvetica Neue", sans-serif;
34
  }
35
 
 
36
  .block-container {
37
  max-width: 1100px;
38
  padding: 1rem 1.5rem;
39
  }
40
 
 
41
  h1, h2, h3 {
42
+ color: #1DB954;
43
  margin-bottom: 0.5rem;
44
  }
45
 
46
+ /* Rounded, bright green button on hover */
47
  .stButton>button {
48
  background-color: #1DB954 !important;
49
  color: #FFFFFF !important;
 
57
  background-color: #1ed760 !important;
58
  }
59
 
60
+ /* Sidebar: black background */
61
  .sidebar .sidebar-content {
62
  background-color: #000000;
63
  color: #FFFFFF;
64
  }
65
 
 
66
  textarea, input, select {
67
  border-radius: 8px !important;
68
  background-color: #282828 !important;
 
70
  border: 1px solid #3e3e3e;
71
  }
72
 
73
+ /* Audio styling */
74
  audio {
75
  width: 100%;
76
  margin-top: 1rem;
77
  }
78
 
79
+ /* Lottie container */
80
  .lottie-container {
81
  display: flex;
82
  justify-content: center;
83
  margin-bottom: 20px;
84
  }
85
 
86
+ /* Footer */
87
  .footer-note {
88
  text-align: center;
89
  font-size: 14px;
 
91
  margin-top: 2rem;
92
  }
93
 
94
+ /* Hide Streamlit branding if you wish */
95
  #MainMenu, footer {visibility: hidden;}
96
  </style>
97
  """
98
  st.markdown(CUSTOM_CSS, unsafe_allow_html=True)
99
 
100
  # ---------------------------------------------------------------------
101
+ # 3) LOAD LOTTIE ANIMATION
102
  # ---------------------------------------------------------------------
103
  @st.cache_data
104
  def load_lottie_url(url: str):
 
 
 
105
  r = requests.get(url)
106
  if r.status_code != 200:
107
  return None
108
  return r.json()
109
 
 
110
  LOTTIE_URL = "https://assets3.lottiefiles.com/temp/lf20_Q6h5zV.json"
111
  lottie_animation = load_lottie_url(LOTTIE_URL)
112
 
113
  # ---------------------------------------------------------------------
114
+ # 4) LOAD LLAMA 3 (GATED MODEL) - WITH use_auth_token
115
+ # ---------------------------------------------------------------------
116
+ @st.cache_resource
117
+ def load_llama_pipeline(model_id: str, device: str):
118
+ """
119
+ Load the Llama 3 model from Hugging Face.
120
+ Requires huggingface-cli login if model is gated.
121
+ """
122
+ tokenizer = AutoTokenizer.from_pretrained(model_id, use_auth_token=True)
123
+ model = AutoModelForCausalLM.from_pretrained(
124
+ model_id,
125
+ torch_dtype=torch.float16 if device == "auto" else torch.float32,
126
+ device_map=device,
127
+ use_auth_token=True
128
+ )
129
+ text_gen_pipeline = pipeline(
130
+ "text-generation",
131
+ model=model,
132
+ tokenizer=tokenizer,
133
+ device_map=device
134
+ )
135
+ return text_gen_pipeline
136
+
137
+ # ---------------------------------------------------------------------
138
+ # 5) REFINE SCRIPT (LLAMA)
139
+ # ---------------------------------------------------------------------
140
+ def generate_radio_script(user_input: str, pipeline_llama) -> str:
141
+ system_prompt = (
142
+ "You are a top-tier radio imaging producer using Llama 3. "
143
+ "Take the user's concept and craft a short, creative promo script."
144
+ )
145
+ combined_prompt = f"{system_prompt}\nUser concept: {user_input}\nRefined script:"
146
+
147
+ result = pipeline_llama(
148
+ combined_prompt,
149
+ max_new_tokens=200,
150
+ do_sample=True,
151
+ temperature=0.9
152
+ )
153
+ output_text = result[0]["generated_text"]
154
+
155
+ if "Refined script:" in output_text:
156
+ output_text = output_text.split("Refined script:", 1)[-1].strip()
157
+ output_text += "\n\n(Generated by Llama 3 - Radio Imaging)"
158
+ return output_text
159
+
160
+ # ---------------------------------------------------------------------
161
+ # 6) LOAD MUSICGEN
162
+ # ---------------------------------------------------------------------
163
+ @st.cache_resource
164
+ def load_musicgen_model():
165
+ mg_model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
166
+ mg_processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
167
+ return mg_model, mg_processor
168
+
169
+ # ---------------------------------------------------------------------
170
+ # 7) SIDEBAR
171
  # ---------------------------------------------------------------------
172
  with st.sidebar:
173
  st.header("🎚 Radio Library")
 
180
  st.markdown("<br>", unsafe_allow_html=True)
181
 
182
  # ---------------------------------------------------------------------
183
+ # 8) HEADER
184
  # ---------------------------------------------------------------------
185
  col1, col2 = st.columns([3, 2], gap="large")
186
 
187
  with col1:
188
+ st.title("AI Radio Imaging with Llama 3")
189
+ st.subheader("Gated Model + MusicGen Audio")
190
 
191
  st.markdown(
192
  """
193
+ Create **radio imaging promos** and **jingles** with Llama 3 + MusicGen.
194
+ **Note**: You must have access to `"meta-llama/Llama-3-70B-Instruct"` on Hugging Face,
195
+ and be logged in via `huggingface-cli login`.
 
 
196
  """
197
  )
198
  with col2:
 
205
  st.markdown("---")
206
 
207
  # ---------------------------------------------------------------------
208
+ # 9) SCRIPT GENERATION
209
  # ---------------------------------------------------------------------
210
+ st.subheader("🎙 Step 1: Describe Your Promo Idea")
211
 
212
  prompt = st.text_area(
213
+ "Example: 'A 15-second hype jingle for a morning talk show, fun and energetic.'",
214
  height=120
215
  )
216
 
217
  col_model, col_device = st.columns(2)
218
  with col_model:
219
  llama_model_id = st.text_input(
220
+ "Llama 3 Model ID",
221
+ value="meta-llama/Llama-3-70B-Instruct", # Official ID if you have it
222
+ help="Use the exact name you see on the Hugging Face model page."
223
  )
224
  with col_device:
225
  device_option = st.selectbox(
226
+ "Device (GPU vs CPU)",
227
  ["auto", "cpu"],
228
+ help="If you have GPU, 'auto' tries to use it; CPU might be slow."
229
  )
230
 
 
 
 
231
  if st.button("📝 Generate Promo Script"):
232
  if not prompt.strip():
233
+ st.error("Please type some concept first.")
234
  else:
235
+ with st.spinner("Generating script with Llama 3..."):
236
  try:
237
+ llm_pipeline = load_llama_pipeline(llama_model_id, device_option)
238
+ final_script = generate_radio_script(prompt, llm_pipeline)
239
+ st.session_state["final_script"] = final_script
 
 
240
  st.success("Promo script generated!")
241
+ st.write(final_script)
242
  except Exception as e:
243
+ st.error(f"Llama generation error: {e}")
244
 
245
  st.markdown("---")
246
 
247
  # ---------------------------------------------------------------------
248
+ # 10) AUDIO GENERATION: MUSICGEN
249
  # ---------------------------------------------------------------------
250
+ st.subheader("🎢 Step 2: Generate Audio")
251
 
252
+ audio_length = st.slider("MusicGen Max Tokens (approx track length)", 128, 1024, 512, 64)
253
 
254
  if st.button("🎧 Create Audio with MusicGen"):
255
+ if "final_script" not in st.session_state:
256
+ st.error("No script found. Please generate a script first.")
 
257
  else:
258
+ with st.spinner("Creating audio..."):
259
  try:
 
260
  mg_model, mg_processor = load_musicgen_model()
261
+ text_for_audio = st.session_state["final_script"]
262
+
 
263
  inputs = mg_processor(
264
+ text=[text_for_audio],
265
+ padding=True,
266
+ return_tensors="pt"
267
  )
268
+ audio_values = mg_model.generate(**inputs, max_new_tokens=audio_length)
 
269
  sr = mg_model.config.audio_encoder.sampling_rate
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
270
 
271
+ outfile = "llama3_radio_jingle.wav"
272
+ scipy.io.wavfile.write(outfile, rate=sr, data=audio_values[0, 0].numpy())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
273
 
274
+ st.success("Audio generated! Press play below:")
275
+ st.audio(outfile)
276
+ except Exception as e:
277
+ st.error(f"MusicGen error: {e}")
 
 
 
 
278
 
279
  # ---------------------------------------------------------------------
280
+ # 11) FOOTER
281
  # ---------------------------------------------------------------------
282
  st.markdown("---")
283
  st.markdown(
284
  """
285
  <div class="footer-note">
286
+ © 2025 Radio Imaging with Llama 3 – Built using Hugging Face & Streamlit. <br>
287
+ Log in via <code>huggingface-cli</code> and ensure access to <strong>meta-llama/Llama-3-70B-Instruct</strong>.
288
  </div>
289
  """,
290
  unsafe_allow_html=True