Bils committed (verified)
Commit bcd1e5d · 1 parent: eaf7133

Update app.py

Files changed (1):
  app.py  +60 -45
app.py CHANGED
@@ -39,7 +39,6 @@ def clean_text(text: str) -> str:
     """
     Removes undesired characters (e.g., asterisks) that might not be recognized by the model's vocabulary.
     """
-    # Remove all asterisks.
     return re.sub(r'\*', '', text)
 
 # ---------------------------------------------------------------------
@@ -75,7 +74,6 @@ def get_musicgen_model(model_key: str = "facebook/musicgen-large"):
 
     model = MusicgenForConditionalGeneration.from_pretrained(model_key)
     processor = AutoProcessor.from_pretrained(model_key)
-
     device = "cuda" if torch.cuda.is_available() else "cpu"
     model.to(device)
     MUSICGEN_MODELS[model_key] = (model, processor)
@@ -105,7 +103,6 @@ def generate_script(user_prompt: str, model_id: str, token: str, duration: int):
     """
     try:
         text_pipeline = get_llama_pipeline(model_id, token)
-
         system_prompt = (
             "You are an expert radio imaging producer specializing in sound design and music. "
             f"Based on the user's concept and the selected duration of {duration} seconds, produce the following: "
@@ -114,7 +111,6 @@ def generate_script(user_prompt: str, model_id: str, token: str, duration: int):
             "3. Music styles or track recommendations. Prefix this section with 'Music Suggestions:'."
         )
         combined_prompt = f"{system_prompt}\nUser concept: {user_prompt}\nOutput:"
-
         with torch.inference_mode():
             result = text_pipeline(
                 combined_prompt,
@@ -122,16 +118,13 @@ def generate_script(user_prompt: str, model_id: str, token: str, duration: int):
                 do_sample=True,
                 temperature=0.8
             )
-
         generated_text = result[0]["generated_text"]
         if "Output:" in generated_text:
             generated_text = generated_text.split("Output:")[-1].strip()
-
         # Default placeholders
         voice_script = "No voice-over script found."
         sound_design = "No sound design suggestions found."
         music_suggestions = "No music suggestions found."
-
         # Voice-Over Script
         if "Voice-Over Script:" in generated_text:
             parts = generated_text.split("Voice-Over Script:")
@@ -140,7 +133,6 @@ def generate_script(user_prompt: str, model_id: str, token: str, duration: int):
                 voice_script = voice_script_part.split("Sound Design Suggestions:")[0].strip()
             else:
                 voice_script = voice_script_part.strip()
-
         # Sound Design
         if "Sound Design Suggestions:" in generated_text:
             parts = generated_text.split("Sound Design Suggestions:")
@@ -149,18 +141,47 @@ def generate_script(user_prompt: str, model_id: str, token: str, duration: int):
                 sound_design = sound_design_part.split("Music Suggestions:")[0].strip()
             else:
                 sound_design = sound_design_part.strip()
-
         # Music Suggestions
         if "Music Suggestions:" in generated_text:
             parts = generated_text.split("Music Suggestions:")
             music_suggestions = parts[1].strip()
-
         return voice_script, sound_design, music_suggestions
-
     except Exception as e:
         return f"Error generating script: {e}", "", ""
 
 
+# ---------------------------------------------------------------------
+# Ad Promo Idea Generation Function
+# ---------------------------------------------------------------------
+@spaces.GPU(duration=100)
+def generate_ad_promo_idea(user_prompt: str, model_id: str, token: str):
+    """
+    Generates a creative ad promo idea based on the user's concept.
+    Returns a string containing the ad promo idea.
+    """
+    try:
+        text_pipeline = get_llama_pipeline(model_id, token)
+        system_prompt = (
+            "You are a creative advertising strategist. "
+            "Generate a unique and engaging ad promo idea based on the following concept. "
+            "Include creative angles, potential taglines, and media suggestions."
+        )
+        combined_prompt = f"{system_prompt}\nConcept: {user_prompt}\nAd Promo Idea:"
+        with torch.inference_mode():
+            result = text_pipeline(
+                combined_prompt,
+                max_new_tokens=150,
+                do_sample=True,
+                temperature=0.8
+            )
+        generated_text = result[0]["generated_text"]
+        if "Ad Promo Idea:" in generated_text:
+            generated_text = generated_text.split("Ad Promo Idea:")[-1].strip()
+        return generated_text
+    except Exception as e:
+        return f"Error generating ad promo idea: {e}"
+
+
 # ---------------------------------------------------------------------
 # Voice-Over Generation Function
 # ---------------------------------------------------------------------
@@ -173,17 +194,11 @@ def generate_voice(script: str, tts_model_name: str = "tts_models/en/ljspeech/ta
     try:
         if not script.strip():
             return "Error: No script provided."
-
-        # Clean the script to remove special characters (e.g., asterisks) that may produce warnings
         cleaned_script = clean_text(script)
-
         tts_model = get_tts_model(tts_model_name)
-
-        # Generate and save voice
         output_path = os.path.join(tempfile.gettempdir(), "voice_over.wav")
         tts_model.tts_to_file(text=cleaned_script, file_path=output_path)
         return output_path
-
     except Exception as e:
         return f"Error generating voice: {e}"
 
@@ -200,24 +215,17 @@ def generate_music(prompt: str, audio_length: int):
     try:
         if not prompt.strip():
             return "Error: No music suggestion provided."
-
         model_key = "facebook/musicgen-large"
         musicgen_model, musicgen_processor = get_musicgen_model(model_key)
-
         device = "cuda" if torch.cuda.is_available() else "cpu"
         inputs = musicgen_processor(text=[prompt], padding=True, return_tensors="pt").to(device)
-
         with torch.inference_mode():
             outputs = musicgen_model.generate(**inputs, max_new_tokens=audio_length)
-
         audio_data = outputs[0, 0].cpu().numpy()
         normalized_audio = (audio_data / max(abs(audio_data)) * 32767).astype("int16")
-
         output_path = os.path.join(tempfile.gettempdir(), "musicgen_large_generated_music.wav")
         write(output_path, 44100, normalized_audio)
-
         return output_path
-
     except Exception as e:
         return f"Error generating music: {e}"
 
@@ -229,42 +237,30 @@ def generate_music(prompt: str, audio_length: int):
 def blend_audio(voice_path: str, music_path: str, ducking: bool, duck_level: int = 10):
     """
     Blends two audio files (voice and music).
-    1. If music < voice, loops the music until it meets/exceeds the voice duration.
-    2. If music > voice, trims music to the voice duration.
-    3. If ducking=True, the music is attenuated by 'duck_level' dB while the voice is playing.
     Returns the file path to the blended .wav file.
     """
     try:
         if not os.path.isfile(voice_path) or not os.path.isfile(music_path):
             return "Error: Missing audio files for blending."
-
         voice = AudioSegment.from_wav(voice_path)
         music = AudioSegment.from_wav(music_path)
-
-        voice_len = len(voice)  # in milliseconds
-        music_len = len(music)  # in milliseconds
-
-        # Loop music if it's shorter than the voice
+        voice_len = len(voice)
+        music_len = len(music)
         if music_len < voice_len:
             looped_music = AudioSegment.empty()
             while len(looped_music) < voice_len:
                 looped_music += music
             music = looped_music
-
-        # Trim music if it's longer than the voice
         if len(music) > voice_len:
             music = music[:voice_len]
-
         if ducking:
             ducked_music = music - duck_level
             final_audio = ducked_music.overlay(voice)
         else:
             final_audio = music.overlay(voice)
-
         output_path = os.path.join(tempfile.gettempdir(), "blended_output.wav")
         final_audio.export(output_path, format="wav")
         return output_path
-
     except Exception as e:
         return f"Error blending audio: {e}"
 
@@ -321,13 +317,36 @@ with gr.Blocks(css="""
     gr.Markdown("""
     Welcome to **AI Ads Promo (Demo MVP)**! This platform leverages state-of-the-art AI models to help you generate:
 
-    - **Script**: Generate a compelling voice-over script with LLaMA.
+    - **Ad Promo Ideas**: Generate creative ad concepts.
+    - **Script**: Produce a compelling voice-over script with LLaMA.
     - **Voice Synthesis**: Create natural-sounding voice-overs using Coqui TTS.
-    - **Music Production**: Produce custom music tracks with MusicGen.
-    - **Audio Blending**: Seamlessly blend voice and music with options for ducking.
+    - **Music Production**: Generate custom music tracks with MusicGen.
+    - **Audio Blending**: Seamlessly combine voice and music with ducking options.
     """)
 
     with gr.Tabs():
+        # New Tab: Generate Ad Promo Idea
+        with gr.Tab("💡 Ad Promo Idea"):
+            with gr.Row():
+                ad_concept = gr.Textbox(
+                    label="Ad Concept",
+                    placeholder="Enter your ad concept or idea...",
+                    lines=2
+                )
+            with gr.Row():
+                llama_model_id_idea = gr.Textbox(
+                    label="LLaMA Model ID",
+                    value="meta-llama/Meta-Llama-3-8B-Instruct",
+                    placeholder="Enter a valid Hugging Face model ID"
+                )
+            generate_ad_idea_button = gr.Button("Generate Ad Promo Idea", variant="primary")
+            ad_idea_output = gr.Textbox(label="Generated Ad Promo Idea", lines=5, interactive=False)
+            generate_ad_idea_button.click(
+                fn=lambda concept, model_id: generate_ad_promo_idea(concept, model_id, HF_TOKEN),
+                inputs=[ad_concept, llama_model_id_idea],
+                outputs=ad_idea_output
+            )
+
         # Step 1: Generate Script
         with gr.Tab("📝 Script Generation"):
             with gr.Row():
@@ -353,7 +372,6 @@ with gr.Blocks(css="""
             script_output = gr.Textbox(label="Generated Voice-Over Script", lines=5, interactive=False)
             sound_design_output = gr.Textbox(label="Sound Design Suggestions", lines=3, interactive=False)
             music_suggestion_output = gr.Textbox(label="Music Suggestions", lines=3, interactive=False)
-
             generate_script_button.click(
                 fn=lambda user_prompt, model_id, dur: generate_script(user_prompt, model_id, HF_TOKEN, dur),
                 inputs=[user_prompt, llama_model_id, duration],
@@ -375,7 +393,6 @@ with gr.Blocks(css="""
             )
             generate_voice_button = gr.Button("Generate Voice-Over", variant="primary")
             voice_audio_output = gr.Audio(label="Voice-Over (WAV)", type="filepath")
-
             generate_voice_button.click(
                 fn=lambda script, tts_model: generate_voice(script, tts_model),
                 inputs=[script_output, selected_tts_model],
@@ -395,7 +412,6 @@ with gr.Blocks(css="""
             )
             generate_music_button = gr.Button("Generate Music", variant="primary")
             music_output = gr.Audio(label="Generated Music (WAV)", type="filepath")
-
             generate_music_button.click(
                 fn=lambda music_suggestion, length: generate_music(music_suggestion, length),
                 inputs=[music_suggestion_output, audio_length],
@@ -415,7 +431,6 @@ with gr.Blocks(css="""
             )
             blend_button = gr.Button("Blend Voice + Music", variant="primary")
             blended_output = gr.Audio(label="Final Blended Output (WAV)", type="filepath")
-
             blend_button.click(
                 fn=blend_audio,
                 inputs=[voice_audio_output, music_output, ducking_checkbox, duck_level_slider],
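
For context, here is a minimal sketch (not part of this commit) of how the newly added generate_ad_promo_idea helper could be exercised outside the Gradio UI. It assumes app.py is importable as a module and that HF_TOKEN is the same module-level token the UI lambdas already pass in.

# Hypothetical usage sketch — assumes app.py exposes generate_ad_promo_idea
# and HF_TOKEN at module level, as the diff above suggests.
from app import generate_ad_promo_idea, HF_TOKEN

idea = generate_ad_promo_idea(
    user_prompt="30-second promo for a late-night jazz radio show",
    model_id="meta-llama/Meta-Llama-3-8B-Instruct",
    token=HF_TOKEN,
)
print(idea)  # free-text idea with creative angles, taglines, and media suggestions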