Bils committed · verified
Commit b2b11fc · 1 Parent(s): b3da884

Update app.py

Files changed (1): app.py (+115, -223)
app.py CHANGED
@@ -1,9 +1,7 @@
 import os
 import re
-import logging
 import torch
 import tempfile
-from typing import Tuple, Union
 from scipy.io.wavfile import write
 from pydub import AudioSegment
 from dotenv import load_dotenv
@@ -21,149 +19,102 @@ from transformers import (
 # Coqui TTS
 from TTS.api import TTS
 
-# Kokoro TTS (ensure these are installed)
-# pip install -q kokoro>=0.8.2 soundfile
-# apt-get -qq -y install espeak-ng > /dev/null 2>&1
-from kokoro import KPipeline
-import soundfile as sf
-
 # ---------------------------------------------------------------------
-# Configuration & Logging Setup
+# Load Environment Variables
 # ---------------------------------------------------------------------
 load_dotenv()
 HF_TOKEN = os.getenv("HF_TOKEN")
-if not HF_TOKEN:
-    logging.warning("HF_TOKEN environment variable not set!")
-
-logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
 
+# ---------------------------------------------------------------------
 # Global Model Caches
+# ---------------------------------------------------------------------
 LLAMA_PIPELINES = {}
 MUSICGEN_MODELS = {}
 TTS_MODELS = {}
 
 # ---------------------------------------------------------------------
-# Utility Functions
+# Utility Function: Clean Text
 # ---------------------------------------------------------------------
 def clean_text(text: str) -> str:
     """
-    Clean text by removing undesired characters.
-
-    Args:
-        text (str): Input text to be cleaned.
-
-    Returns:
-        str: Cleaned text.
+    Removes undesired characters (e.g., asterisks) that might not be recognized by the model's vocabulary.
     """
-    # Remove all asterisks. Additional cleaning rules can be added.
+    # Remove all asterisks. You can add more cleaning steps here as needed.
     return re.sub(r'\*', '', text)
 
 # ---------------------------------------------------------------------
-# Model Loading Helper Functions
+# Helper Functions
 # ---------------------------------------------------------------------
-def get_llama_pipeline(model_id: str, token: str) -> pipeline:
+def get_llama_pipeline(model_id: str, token: str):
     """
-    Load and cache the LLaMA text-generation pipeline.
-
-    Args:
-        model_id (str): Hugging Face model identifier.
-        token (str): Hugging Face authentication token.
-
-    Returns:
-        pipeline: Text-generation pipeline instance.
+    Returns a cached LLaMA pipeline if available; otherwise, loads it.
     """
     if model_id in LLAMA_PIPELINES:
        return LLAMA_PIPELINES[model_id]
 
-    try:
-        tokenizer = AutoTokenizer.from_pretrained(model_id, use_auth_token=token)
-        model = AutoModelForCausalLM.from_pretrained(
-            model_id,
-            use_auth_token=token,
-            torch_dtype=torch.float16,
-            device_map="auto",
-            trust_remote_code=True,
-        )
-        text_pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer)
-        LLAMA_PIPELINES[model_id] = text_pipeline
-        return text_pipeline
-    except Exception as e:
-        logging.error(f"Error loading LLaMA pipeline: {e}")
-        raise
+    tokenizer = AutoTokenizer.from_pretrained(model_id, use_auth_token=token)
+    model = AutoModelForCausalLM.from_pretrained(
+        model_id,
+        use_auth_token=token,
+        torch_dtype=torch.float16,
+        device_map="auto",
+        trust_remote_code=True,
+    )
+    text_pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer)
+    LLAMA_PIPELINES[model_id] = text_pipeline
+    return text_pipeline
 
-def get_musicgen_model(model_key: str = "facebook/musicgen-large") -> Tuple[MusicgenForConditionalGeneration, AutoProcessor]:
+
+def get_musicgen_model(model_key: str = "facebook/musicgen-large"):
     """
-    Load and cache the MusicGen model and its processor.
-
-    Args:
-        model_key (str): Model key (default uses 'facebook/musicgen-large').
-
-    Returns:
-        tuple: (MusicGen model, processor)
+    Returns a cached MusicGen model if available; otherwise, loads it.
+    Uses the 'large' variant for higher quality outputs.
     """
     if model_key in MUSICGEN_MODELS:
         return MUSICGEN_MODELS[model_key]
 
-    try:
-        model = MusicgenForConditionalGeneration.from_pretrained(model_key)
-        processor = AutoProcessor.from_pretrained(model_key)
-        device = "cuda" if torch.cuda.is_available() else "cpu"
-        model.to(device)
-        MUSICGEN_MODELS[model_key] = (model, processor)
-        return model, processor
-    except Exception as e:
-        logging.error(f"Error loading MusicGen model: {e}")
-        raise
+    model = MusicgenForConditionalGeneration.from_pretrained(model_key)
+    processor = AutoProcessor.from_pretrained(model_key)
+
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    model.to(device)
+    MUSICGEN_MODELS[model_key] = (model, processor)
+    return model, processor
 
-def get_tts_model(model_name: str = "tts_models/en/ljspeech/tacotron2-DDC") -> TTS:
+
+def get_tts_model(model_name: str = "tts_models/en/ljspeech/tacotron2-DDC"):
     """
-    Load and cache the TTS model.
-
-    Args:
-        model_name (str): Name of the TTS model.
-
-    Returns:
-        TTS: TTS model instance.
+    Returns a cached TTS model if available; otherwise, loads it.
     """
     if model_name in TTS_MODELS:
         return TTS_MODELS[model_name]
 
-    try:
-        tts_model = TTS(model_name)
-        TTS_MODELS[model_name] = tts_model
-        return tts_model
-    except Exception as e:
-        logging.error(f"Error loading TTS model: {e}")
-        raise
+    tts_model = TTS(model_name)
+    TTS_MODELS[model_name] = tts_model
+    return tts_model
+
 
 # ---------------------------------------------------------------------
 # Script Generation Function
 # ---------------------------------------------------------------------
 @spaces.GPU(duration=100)
-def generate_script(user_prompt: str, model_id: str, token: str, duration: int) -> Tuple[str, str, str]:
+def generate_script(user_prompt: str, model_id: str, token: str, duration: int):
     """
-    Generate a script, sound design suggestions, and music ideas from a user prompt.
-
-    Args:
-        user_prompt (str): The user's creative input.
-        model_id (str): Hugging Face model identifier for LLaMA.
-        token (str): Hugging Face authentication token.
-        duration (int): Desired duration of the promo in seconds.
-
-    Returns:
-        tuple: (voice_script, sound_design, music_suggestions)
+    Generates a script, sound design suggestions, and music ideas from a user prompt.
+    Returns a tuple of strings: (voice_script, sound_design, music_suggestions).
     """
     try:
         text_pipeline = get_llama_pipeline(model_id, token)
+
         system_prompt = (
             "You are an expert radio imaging producer specializing in sound design and music. "
-            f"Based on the user's concept and the selected duration of {duration} seconds, produce the following:\n"
-            "1. A concise voice-over script. Prefix this section with 'Voice-Over Script:'\n"
-            "2. Suggestions for sound design. Prefix this section with 'Sound Design Suggestions:'\n"
-            "3. Music styles or track recommendations. Prefix this section with 'Music Suggestions:'"
+            f"Based on the user's concept and the selected duration of {duration} seconds, produce the following: "
+            "1. A concise voice-over script. Prefix this section with 'Voice-Over Script:'.\n"
+            "2. Suggestions for sound design. Prefix this section with 'Sound Design Suggestions:'.\n"
+            "3. Music styles or track recommendations. Prefix this section with 'Music Suggestions:'."
         )
         combined_prompt = f"{system_prompt}\nUser concept: {user_prompt}\nOutput:"
-
+
         with torch.inference_mode():
             result = text_pipeline(
                 combined_prompt,
@@ -173,130 +124,86 @@ def generate_script(user_prompt: str, model_id: str, token: str, duration: int)
             )
 
         generated_text = result[0]["generated_text"]
-        # Remove everything before the 'Output:' marker if present
         if "Output:" in generated_text:
             generated_text = generated_text.split("Output:")[-1].strip()
 
-        # Initialize default outputs
+        # Default placeholders
         voice_script = "No voice-over script found."
         sound_design = "No sound design suggestions found."
         music_suggestions = "No music suggestions found."
 
-        # Parse generated text based on expected prefixes
+        # Voice-Over Script
         if "Voice-Over Script:" in generated_text:
-            voice_section = generated_text.split("Voice-Over Script:")[1]
-            if "Sound Design Suggestions:" in voice_section:
-                voice_script = voice_section.split("Sound Design Suggestions:")[0].strip()
+            parts = generated_text.split("Voice-Over Script:")
+            voice_script_part = parts[1]
+            if "Sound Design Suggestions:" in voice_script_part:
+                voice_script = voice_script_part.split("Sound Design Suggestions:")[0].strip()
             else:
-                voice_script = voice_section.strip()
+                voice_script = voice_script_part.strip()
 
+        # Sound Design
         if "Sound Design Suggestions:" in generated_text:
-            sound_section = generated_text.split("Sound Design Suggestions:")[1]
-            if "Music Suggestions:" in sound_section:
-                sound_design = sound_section.split("Music Suggestions:")[0].strip()
+            parts = generated_text.split("Sound Design Suggestions:")
+            sound_design_part = parts[1]
+            if "Music Suggestions:" in sound_design_part:
+                sound_design = sound_design_part.split("Music Suggestions:")[0].strip()
             else:
-                sound_design = sound_section.strip()
+                sound_design = sound_design_part.strip()
 
+        # Music Suggestions
         if "Music Suggestions:" in generated_text:
-            music_suggestions = generated_text.split("Music Suggestions:")[-1].strip()
+            parts = generated_text.split("Music Suggestions:")
+            music_suggestions = parts[1].strip()
 
         return voice_script, sound_design, music_suggestions
 
     except Exception as e:
-        logging.error(f"Error in generate_script: {e}")
         return f"Error generating script: {e}", "", ""
 
+
 # ---------------------------------------------------------------------
-# Voice-Over Generation Functions
+# Voice-Over Generation Function
 # ---------------------------------------------------------------------
 @spaces.GPU(duration=100)
-def generate_voice(script: str, tts_model_name: str = "tts_models/en/ljspeech/tacotron2-DDC") -> Union[str, None]:
+def generate_voice(script: str, tts_model_name: str = "tts_models/en/ljspeech/tacotron2-DDC"):
     """
-    Generate a voice-over audio file using Coqui TTS from the provided script.
-
-    Args:
-        script (str): The voice-over script.
-        tts_model_name (str): TTS model identifier.
-
-    Returns:
-        str: File path to the generated .wav file or an error message.
+    Generates a voice-over from the provided script using the Coqui TTS model.
+    Returns the file path to the generated .wav file.
     """
     try:
         if not script.strip():
-            raise ValueError("No script provided.")
+            return "Error: No script provided."
+
+        # Clean the script to remove special characters (e.g., asterisks) that may produce warnings
         cleaned_script = clean_text(script)
+
         tts_model = get_tts_model(tts_model_name)
-        output_path = os.path.join(tempfile.gettempdir(), "voice_over_coqui.wav")
+
+        # Generate and save voice
+        output_path = os.path.join(tempfile.gettempdir(), "voice_over.wav")
         tts_model.tts_to_file(text=cleaned_script, file_path=output_path)
-        logging.info(f"Coqui voice-over generated at {output_path}")
         return output_path
 
     except Exception as e:
-        logging.error(f"Error in generate_voice (Coqui TTS): {e}")
         return f"Error generating voice: {e}"
 
-@spaces.GPU(duration=100)
-def generate_voice_kokoro(script: str, lang_code: str = 'a', voice: str = 'af_heart', speed: float = 1.0) -> Union[str, None]:
-    """
-    Generate a voice-over audio file using the Kokoro TTS model.
-
-    Args:
-        script (str): The text to synthesize.
-        lang_code (str): Language code ('a' for American English, etc.).
-        voice (str): Specific voice style.
-        speed (float): Speech speed.
-
-    Returns:
-        str: File path to the generated WAV file or an error message.
-    """
-    try:
-        # Initialize the Kokoro pipeline
-        kp = KPipeline(lang_code=lang_code)
-        audio_segments = []
-        generator = kp(script, voice=voice, speed=speed, split_pattern=r'\n+')
-        for i, (gs, ps, audio) in enumerate(generator):
-            audio_segments.append(audio)
-
-        # Join audio segments using pydub
-        combined = AudioSegment.empty()
-        for seg in audio_segments:
-            segment = AudioSegment(
-                seg.tobytes(),
-                frame_rate=24000,
-                sample_width=seg.dtype.itemsize,
-                channels=1
-            )
-            combined += segment
-
-        output_path = os.path.join(tempfile.gettempdir(), "voice_over_kokoro.wav")
-        combined.export(output_path, format="wav")
-        logging.info(f"Kokoro voice-over generated at {output_path}")
-        return output_path
-
-    except Exception as e:
-        logging.error(f"Error in generate_voice_kokoro: {e}")
-        return f"Error generating Kokoro voice: {e}"
 
 # ---------------------------------------------------------------------
 # Music Generation Function
 # ---------------------------------------------------------------------
 @spaces.GPU(duration=200)
-def generate_music(prompt: str, audio_length: int) -> Union[str, None]:
+def generate_music(prompt: str, audio_length: int):
     """
-    Generate music based on the prompt using MusicGen.
-
-    Args:
-        prompt (str): Music prompt or style suggestion.
-        audio_length (int): Length parameter (number of tokens).
-
-    Returns:
-        str: File path to the generated .wav file or an error message.
+    Generates music from the 'facebook/musicgen-large' model based on the prompt.
+    Returns the file path to the generated .wav file.
    """
     try:
         if not prompt.strip():
-            raise ValueError("No music suggestion provided.")
+            return "Error: No music suggestion provided."
+
         model_key = "facebook/musicgen-large"
         musicgen_model, musicgen_processor = get_musicgen_model(model_key)
+
         device = "cuda" if torch.cuda.is_available() else "cpu"
         inputs = musicgen_processor(text=[prompt], padding=True, return_tensors="pt").to(device)
 
@@ -305,47 +212,48 @@ def generate_music(prompt: str, audio_length: int) -> Union[str, None]:
 
         audio_data = outputs[0, 0].cpu().numpy()
         normalized_audio = (audio_data / max(abs(audio_data)) * 32767).astype("int16")
-        output_path = os.path.join(tempfile.gettempdir(), "musicgen_generated_music.wav")
+
+        output_path = os.path.join(tempfile.gettempdir(), "musicgen_large_generated_music.wav")
         write(output_path, 44100, normalized_audio)
-        logging.info(f"Music generated at {output_path}")
+
        return output_path
 
     except Exception as e:
-        logging.error(f"Error in generate_music: {e}")
         return f"Error generating music: {e}"
 
+
 # ---------------------------------------------------------------------
-# Audio Blending Function
+# Audio Blending with Duration Sync & Ducking
 # ---------------------------------------------------------------------
 @spaces.GPU(duration=100)
-def blend_audio(voice_path: str, music_path: str, ducking: bool, duck_level: int = 10) -> Union[str, None]:
+def blend_audio(voice_path: str, music_path: str, ducking: bool, duck_level: int = 10):
     """
-    Blend voice and music audio files with optional ducking.
-
-    Args:
-        voice_path (str): File path to the voice audio.
-        music_path (str): File path to the music audio.
-        ducking (bool): If True, attenuate music during voice segments.
-        duck_level (int): Attenuation level in dB.
-
-    Returns:
-        str: File path to the blended .wav file or an error message.
+    Blends two audio files (voice and music).
+    1. If music < voice, loops the music until it meets/exceeds the voice duration.
+    2. If music > voice, trims music to the voice duration.
+    3. If ducking=True, the music is attenuated by 'duck_level' dB while the voice is playing.
+    Returns the file path to the blended .wav file.
    """
     try:
-        if not (os.path.isfile(voice_path) and os.path.isfile(music_path)):
-            raise FileNotFoundError("Missing audio files for blending.")
+        if not os.path.isfile(voice_path) or not os.path.isfile(music_path):
+            return "Error: Missing audio files for blending."
 
         voice = AudioSegment.from_wav(voice_path)
         music = AudioSegment.from_wav(music_path)
-        voice_duration = len(voice)
 
-        if len(music) < voice_duration:
+        voice_len = len(voice)  # in milliseconds
+        music_len = len(music)  # in milliseconds
+
+        # Loop music if it's shorter than the voice
+        if music_len < voice_len:
             looped_music = AudioSegment.empty()
-            while len(looped_music) < voice_duration:
+            while len(looped_music) < voice_len:
                 looped_music += music
             music = looped_music
-        else:
-            music = music[:voice_duration]
+
+        # Trim music if it's longer than the voice
+        if len(music) > voice_len:
+            music = music[:voice_len]
 
         if ducking:
             ducked_music = music - duck_level
@@ -355,13 +263,12 @@ def blend_audio(voice_path: str, music_path: str, ducking: bool, duck_level: int
 
         output_path = os.path.join(tempfile.gettempdir(), "blended_output.wav")
         final_audio.export(output_path, format="wav")
-        logging.info(f"Audio blended at {output_path}")
         return output_path
 
     except Exception as e:
-        logging.error(f"Error in blend_audio: {e}")
         return f"Error blending audio: {e}"
 
+
 # ---------------------------------------------------------------------
 # Gradio Interface with Enhanced UI
 # ---------------------------------------------------------------------
@@ -415,7 +322,7 @@ with gr.Blocks(css="""
     Welcome to **AI Promo Studio**! This platform leverages state-of-the-art AI models to help you generate:
 
     - **Script**: Generate a compelling voice-over script with LLaMA.
-    - **Voice Synthesis**: Create natural-sounding voice-overs using Coqui TTS or Kokoro TTS.
+    - **Voice Synthesis**: Create natural-sounding voice-overs using Coqui TTS.
    - **Music Production**: Produce custom music tracks with MusicGen.
     - **Audio Blending**: Seamlessly blend voice and music with options for ducking.
     """)
@@ -448,26 +355,20 @@ with gr.Blocks(css="""
            music_suggestion_output = gr.Textbox(label="Music Suggestions", lines=3, interactive=False)
 
            generate_script_button.click(
-                fn=lambda prompt, model, dur: generate_script(prompt, model, HF_TOKEN, dur),
+                fn=lambda user_prompt, model_id, dur: generate_script(user_prompt, model_id, HF_TOKEN, dur),
                inputs=[user_prompt, llama_model_id, duration],
                outputs=[script_output, sound_design_output, music_suggestion_output],
            )
 
        # Step 2: Generate Voice
        with gr.Tab("🎤 Voice Synthesis"):
-            gr.Markdown("Generate a natural-sounding voice-over. Choose your TTS engine below:")
-            voice_engine = gr.Dropdown(
-                label="TTS Engine",
-                choices=["Coqui TTS", "Kokoro TTS"],
-                value="Coqui TTS",
-                multiselect=False
-            )
+            gr.Markdown("Generate a natural-sounding voice-over using Coqui TTS.")
            selected_tts_model = gr.Dropdown(
-                label="TTS Model / Voice Option",
+                label="TTS Model",
                choices=[
-                    "tts_models/en/ljspeech/tacotron2-DDC",  # Coqui TTS option
-                    "tts_models/en/ljspeech/vits",           # Coqui TTS option
-                    "af_heart"                               # Kokoro TTS voice option
+                    "tts_models/en/ljspeech/tacotron2-DDC",
+                    "tts_models/en/ljspeech/vits",
+                    "tts_models/en/sam/tacotron-DDC",
                ],
                value="tts_models/en/ljspeech/tacotron2-DDC",
                multiselect=False
@@ -475,18 +376,9 @@ with gr.Blocks(css="""
            generate_voice_button = gr.Button("Generate Voice-Over", variant="primary")
            voice_audio_output = gr.Audio(label="Voice-Over (WAV)", type="filepath")
 
-            def generate_voice_combined(script, engine, model_choice):
-                if engine == "Coqui TTS":
-                    return generate_voice(script, model_choice)
-                elif engine == "Kokoro TTS":
-                    # For Kokoro, pass the voice option (e.g., "af_heart") and default language code ('a')
-                    return generate_voice_kokoro(script, lang_code='a', voice=model_choice, speed=1.0)
-                else:
-                    return "Error: Unknown TTS engine."
-
            generate_voice_button.click(
-                fn=generate_voice_combined,
-                inputs=[script_output, voice_engine, selected_tts_model],
+                fn=lambda script, tts_model: generate_voice(script, tts_model),
+                inputs=[script_output, selected_tts_model],
                outputs=voice_audio_output,
            )
 
@@ -505,7 +397,7 @@ with gr.Blocks(css="""
            music_output = gr.Audio(label="Generated Music (WAV)", type="filepath")
 
            generate_music_button.click(
-                fn=lambda prompt, length: generate_music(prompt, length),
+                fn=lambda music_suggestion, length: generate_music(music_suggestion, length),
                inputs=[music_suggestion_output, audio_length],
                outputs=[music_output],
            )
 
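A side note on the loader helpers: recent transformers releases have been deprecating the use_auth_token keyword in favor of token, so the from_pretrained calls above may emit warnings depending on the pinned version. A minimal sketch of the newer form, with an illustrative model id (check against the version the Space actually installs):

import os
from transformers import AutoTokenizer

# `token=` replaces the deprecated `use_auth_token=` on recent transformers releases.
hf_token = os.getenv("HF_TOKEN")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf", token=hf_token)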
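Taken together, the simplified functions also chain straightforwardly outside of Gradio. A usage sketch, assuming app.py is importable and HF_TOKEN is set; the prompt, model id, and audio_length (in tokens) are illustrative values:

import os
from app import generate_script, generate_voice, generate_music, blend_audio

token = os.getenv("HF_TOKEN")

# Step 1: script, then Step 2-4: voice, music, and the blended promo.
script, sound_design, music_ideas = generate_script(
    "High-energy summer festival promo", "meta-llama/Llama-2-7b-chat-hf", token, 30
)
voice_path = generate_voice(script, "tts_models/en/ljspeech/tacotron2-DDC")
music_path = generate_music(music_ideas, 512)
final_path = blend_audio(voice_path, music_path, ducking=True, duck_level=10)
print(f"Blended promo written to {final_path}")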