Bils committed
Commit 98ff03c · verified · 1 Parent(s): 93b1697

Update app.py

Files changed (1)
  1. app.py +223 -115
app.py CHANGED
@@ -1,7 +1,9 @@
 import os
 import re
+import logging
 import torch
 import tempfile
+from typing import Tuple, Union
 from scipy.io.wavfile import write
 from pydub import AudioSegment
 from dotenv import load_dotenv
@@ -19,102 +21,149 @@ from transformers import (
 # Coqui TTS
 from TTS.api import TTS
 
+# Kokoro TTS (ensure these are installed)
+# pip install -q kokoro>=0.8.2 soundfile
+# apt-get -qq -y install espeak-ng > /dev/null 2>&1
+from kokoro import KPipeline
+import soundfile as sf
+
 # ---------------------------------------------------------------------
-# Load Environment Variables
+# Configuration & Logging Setup
 # ---------------------------------------------------------------------
 load_dotenv()
 HF_TOKEN = os.getenv("HF_TOKEN")
+if not HF_TOKEN:
+    logging.warning("HF_TOKEN environment variable not set!")
+
+logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
 
-# ---------------------------------------------------------------------
 # Global Model Caches
-# ---------------------------------------------------------------------
 LLAMA_PIPELINES = {}
 MUSICGEN_MODELS = {}
 TTS_MODELS = {}
 
 # ---------------------------------------------------------------------
-# Utility Function: Clean Text
+# Utility Functions
 # ---------------------------------------------------------------------
 def clean_text(text: str) -> str:
     """
-    Removes undesired characters (e.g., asterisks) that might not be recognized by the model's vocabulary.
+    Clean text by removing undesired characters.
+
+    Args:
+        text (str): Input text to be cleaned.
+
+    Returns:
+        str: Cleaned text.
     """
-    # Remove all asterisks. You can add more cleaning steps here as needed.
+    # Remove all asterisks. Additional cleaning rules can be added.
     return re.sub(r'\*', '', text)
 
 # ---------------------------------------------------------------------
-# Helper Functions
+# Model Loading Helper Functions
 # ---------------------------------------------------------------------
-def get_llama_pipeline(model_id: str, token: str):
+def get_llama_pipeline(model_id: str, token: str) -> pipeline:
     """
-    Returns a cached LLaMA pipeline if available; otherwise, loads it.
+    Load and cache the LLaMA text-generation pipeline.
+
+    Args:
+        model_id (str): Hugging Face model identifier.
+        token (str): Hugging Face authentication token.
+
+    Returns:
+        pipeline: Text-generation pipeline instance.
     """
     if model_id in LLAMA_PIPELINES:
         return LLAMA_PIPELINES[model_id]
 
-    tokenizer = AutoTokenizer.from_pretrained(model_id, use_auth_token=token)
-    model = AutoModelForCausalLM.from_pretrained(
-        model_id,
-        use_auth_token=token,
-        torch_dtype=torch.float16,
-        device_map="auto",
-        trust_remote_code=True,
-    )
-    text_pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer)
-    LLAMA_PIPELINES[model_id] = text_pipeline
-    return text_pipeline
-
+    try:
+        tokenizer = AutoTokenizer.from_pretrained(model_id, use_auth_token=token)
+        model = AutoModelForCausalLM.from_pretrained(
+            model_id,
+            use_auth_token=token,
+            torch_dtype=torch.float16,
+            device_map="auto",
+            trust_remote_code=True,
+        )
+        text_pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer)
+        LLAMA_PIPELINES[model_id] = text_pipeline
+        return text_pipeline
+    except Exception as e:
+        logging.error(f"Error loading LLaMA pipeline: {e}")
+        raise
 
-def get_musicgen_model(model_key: str = "facebook/musicgen-large"):
+def get_musicgen_model(model_key: str = "facebook/musicgen-large") -> Tuple[MusicgenForConditionalGeneration, AutoProcessor]:
     """
-    Returns a cached MusicGen model if available; otherwise, loads it.
-    Uses the 'large' variant for higher quality outputs.
+    Load and cache the MusicGen model and its processor.
+
+    Args:
+        model_key (str): Model key (default uses 'facebook/musicgen-large').
+
+    Returns:
+        tuple: (MusicGen model, processor)
     """
     if model_key in MUSICGEN_MODELS:
         return MUSICGEN_MODELS[model_key]
 
-    model = MusicgenForConditionalGeneration.from_pretrained(model_key)
-    processor = AutoProcessor.from_pretrained(model_key)
-
-    device = "cuda" if torch.cuda.is_available() else "cpu"
-    model.to(device)
-    MUSICGEN_MODELS[model_key] = (model, processor)
-    return model, processor
-
+    try:
+        model = MusicgenForConditionalGeneration.from_pretrained(model_key)
+        processor = AutoProcessor.from_pretrained(model_key)
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+        model.to(device)
+        MUSICGEN_MODELS[model_key] = (model, processor)
+        return model, processor
+    except Exception as e:
+        logging.error(f"Error loading MusicGen model: {e}")
+        raise
 
-def get_tts_model(model_name: str = "tts_models/en/ljspeech/tacotron2-DDC"):
+def get_tts_model(model_name: str = "tts_models/en/ljspeech/tacotron2-DDC") -> TTS:
     """
-    Returns a cached TTS model if available; otherwise, loads it.
+    Load and cache the TTS model.
+
+    Args:
+        model_name (str): Name of the TTS model.
+
+    Returns:
+        TTS: TTS model instance.
     """
     if model_name in TTS_MODELS:
         return TTS_MODELS[model_name]
 
-    tts_model = TTS(model_name)
-    TTS_MODELS[model_name] = tts_model
-    return tts_model
-
+    try:
+        tts_model = TTS(model_name)
+        TTS_MODELS[model_name] = tts_model
+        return tts_model
+    except Exception as e:
+        logging.error(f"Error loading TTS model: {e}")
+        raise
 
 # ---------------------------------------------------------------------
 # Script Generation Function
 # ---------------------------------------------------------------------
 @spaces.GPU(duration=100)
-def generate_script(user_prompt: str, model_id: str, token: str, duration: int):
+def generate_script(user_prompt: str, model_id: str, token: str, duration: int) -> Tuple[str, str, str]:
     """
-    Generates a script, sound design suggestions, and music ideas from a user prompt.
-    Returns a tuple of strings: (voice_script, sound_design, music_suggestions).
+    Generate a script, sound design suggestions, and music ideas from a user prompt.
+
+    Args:
+        user_prompt (str): The user's creative input.
+        model_id (str): Hugging Face model identifier for LLaMA.
+        token (str): Hugging Face authentication token.
+        duration (int): Desired duration of the promo in seconds.
+
+    Returns:
+        tuple: (voice_script, sound_design, music_suggestions)
     """
     try:
         text_pipeline = get_llama_pipeline(model_id, token)
-
         system_prompt = (
             "You are an expert radio imaging producer specializing in sound design and music. "
-            f"Based on the user's concept and the selected duration of {duration} seconds, produce the following: "
-            "1. A concise voice-over script. Prefix this section with 'Voice-Over Script:'.\n"
-            "2. Suggestions for sound design. Prefix this section with 'Sound Design Suggestions:'.\n"
-            "3. Music styles or track recommendations. Prefix this section with 'Music Suggestions:'."
+            f"Based on the user's concept and the selected duration of {duration} seconds, produce the following:\n"
+            "1. A concise voice-over script. Prefix this section with 'Voice-Over Script:'\n"
+            "2. Suggestions for sound design. Prefix this section with 'Sound Design Suggestions:'\n"
+            "3. Music styles or track recommendations. Prefix this section with 'Music Suggestions:'"
        )
         combined_prompt = f"{system_prompt}\nUser concept: {user_prompt}\nOutput:"
 
         with torch.inference_mode():
             result = text_pipeline(
                 combined_prompt,
@@ -124,86 +173,130 @@ def generate_script(user_prompt: str, model_id: str, token: str, duration: int):
             )
 
         generated_text = result[0]["generated_text"]
+        # Remove everything before the 'Output:' marker if present
         if "Output:" in generated_text:
             generated_text = generated_text.split("Output:")[-1].strip()
 
-        # Default placeholders
+        # Initialize default outputs
         voice_script = "No voice-over script found."
         sound_design = "No sound design suggestions found."
         music_suggestions = "No music suggestions found."
 
-        # Voice-Over Script
+        # Parse generated text based on expected prefixes
         if "Voice-Over Script:" in generated_text:
-            parts = generated_text.split("Voice-Over Script:")
-            voice_script_part = parts[1]
-            if "Sound Design Suggestions:" in voice_script_part:
-                voice_script = voice_script_part.split("Sound Design Suggestions:")[0].strip()
+            voice_section = generated_text.split("Voice-Over Script:")[1]
+            if "Sound Design Suggestions:" in voice_section:
+                voice_script = voice_section.split("Sound Design Suggestions:")[0].strip()
             else:
-                voice_script = voice_script_part.strip()
+                voice_script = voice_section.strip()
 
-        # Sound Design
         if "Sound Design Suggestions:" in generated_text:
-            parts = generated_text.split("Sound Design Suggestions:")
-            sound_design_part = parts[1]
-            if "Music Suggestions:" in sound_design_part:
-                sound_design = sound_design_part.split("Music Suggestions:")[0].strip()
+            sound_section = generated_text.split("Sound Design Suggestions:")[1]
+            if "Music Suggestions:" in sound_section:
+                sound_design = sound_section.split("Music Suggestions:")[0].strip()
             else:
-                sound_design = sound_design_part.strip()
+                sound_design = sound_section.strip()
 
-        # Music Suggestions
         if "Music Suggestions:" in generated_text:
-            parts = generated_text.split("Music Suggestions:")
-            music_suggestions = parts[1].strip()
+            music_suggestions = generated_text.split("Music Suggestions:")[-1].strip()
 
         return voice_script, sound_design, music_suggestions
 
     except Exception as e:
+        logging.error(f"Error in generate_script: {e}")
         return f"Error generating script: {e}", "", ""
 
-
 # ---------------------------------------------------------------------
-# Voice-Over Generation Function
+# Voice-Over Generation Functions
 # ---------------------------------------------------------------------
 @spaces.GPU(duration=100)
-def generate_voice(script: str, tts_model_name: str = "tts_models/en/ljspeech/tacotron2-DDC"):
+def generate_voice(script: str, tts_model_name: str = "tts_models/en/ljspeech/tacotron2-DDC") -> Union[str, None]:
     """
-    Generates a voice-over from the provided script using the Coqui TTS model.
-    Returns the file path to the generated .wav file.
+    Generate a voice-over audio file using Coqui TTS from the provided script.
+
+    Args:
+        script (str): The voice-over script.
+        tts_model_name (str): TTS model identifier.
+
+    Returns:
+        str: File path to the generated .wav file or an error message.
     """
     try:
         if not script.strip():
-            return "Error: No script provided."
-
-        # Clean the script to remove special characters (e.g., asterisks) that may produce warnings
+            raise ValueError("No script provided.")
         cleaned_script = clean_text(script)
-
         tts_model = get_tts_model(tts_model_name)
-
-        # Generate and save voice
-        output_path = os.path.join(tempfile.gettempdir(), "voice_over.wav")
+        output_path = os.path.join(tempfile.gettempdir(), "voice_over_coqui.wav")
         tts_model.tts_to_file(text=cleaned_script, file_path=output_path)
+        logging.info(f"Coqui voice-over generated at {output_path}")
         return output_path
 
     except Exception as e:
+        logging.error(f"Error in generate_voice (Coqui TTS): {e}")
         return f"Error generating voice: {e}"
 
+@spaces.GPU(duration=100)
+def generate_voice_kokoro(script: str, lang_code: str = 'a', voice: str = 'af_heart', speed: float = 1.0) -> Union[str, None]:
+    """
+    Generate a voice-over audio file using the Kokoro TTS model.
+
+    Args:
+        script (str): The text to synthesize.
+        lang_code (str): Language code ('a' for American English, etc.).
+        voice (str): Specific voice style.
+        speed (float): Speech speed.
+
+    Returns:
+        str: File path to the generated WAV file or an error message.
+    """
+    try:
+        # Initialize the Kokoro pipeline
+        kp = KPipeline(lang_code=lang_code)
+        audio_segments = []
+        generator = kp(script, voice=voice, speed=speed, split_pattern=r'\n+')
+        for i, (gs, ps, audio) in enumerate(generator):
+            audio_segments.append(audio)
+
+        # Join audio segments using pydub
+        combined = AudioSegment.empty()
+        for seg in audio_segments:
+            segment = AudioSegment(
+                seg.tobytes(),
+                frame_rate=24000,
+                sample_width=seg.dtype.itemsize,
+                channels=1
+            )
+            combined += segment
+
+        output_path = os.path.join(tempfile.gettempdir(), "voice_over_kokoro.wav")
+        combined.export(output_path, format="wav")
+        logging.info(f"Kokoro voice-over generated at {output_path}")
+        return output_path
+
+    except Exception as e:
+        logging.error(f"Error in generate_voice_kokoro: {e}")
+        return f"Error generating Kokoro voice: {e}"
 
 # ---------------------------------------------------------------------
 # Music Generation Function
 # ---------------------------------------------------------------------
 @spaces.GPU(duration=200)
-def generate_music(prompt: str, audio_length: int):
+def generate_music(prompt: str, audio_length: int) -> Union[str, None]:
     """
-    Generates music from the 'facebook/musicgen-large' model based on the prompt.
-    Returns the file path to the generated .wav file.
+    Generate music based on the prompt using MusicGen.
+
+    Args:
+        prompt (str): Music prompt or style suggestion.
+        audio_length (int): Length parameter (number of tokens).
+
+    Returns:
+        str: File path to the generated .wav file or an error message.
     """
     try:
         if not prompt.strip():
-            return "Error: No music suggestion provided."
-
+            raise ValueError("No music suggestion provided.")
         model_key = "facebook/musicgen-large"
         musicgen_model, musicgen_processor = get_musicgen_model(model_key)
-
         device = "cuda" if torch.cuda.is_available() else "cpu"
         inputs = musicgen_processor(text=[prompt], padding=True, return_tensors="pt").to(device)
 
@@ -212,48 +305,47 @@ def generate_music(prompt: str, audio_length: int):
 
         audio_data = outputs[0, 0].cpu().numpy()
         normalized_audio = (audio_data / max(abs(audio_data)) * 32767).astype("int16")
-
-        output_path = os.path.join(tempfile.gettempdir(), "musicgen_large_generated_music.wav")
+        output_path = os.path.join(tempfile.gettempdir(), "musicgen_generated_music.wav")
         write(output_path, 44100, normalized_audio)
-
+        logging.info(f"Music generated at {output_path}")
         return output_path
 
     except Exception as e:
+        logging.error(f"Error in generate_music: {e}")
         return f"Error generating music: {e}"
 
-
 # ---------------------------------------------------------------------
-# Audio Blending with Duration Sync & Ducking
+# Audio Blending Function
 # ---------------------------------------------------------------------
 @spaces.GPU(duration=100)
-def blend_audio(voice_path: str, music_path: str, ducking: bool, duck_level: int = 10):
+def blend_audio(voice_path: str, music_path: str, ducking: bool, duck_level: int = 10) -> Union[str, None]:
     """
-    Blends two audio files (voice and music).
-    1. If music < voice, loops the music until it meets/exceeds the voice duration.
-    2. If music > voice, trims music to the voice duration.
-    3. If ducking=True, the music is attenuated by 'duck_level' dB while the voice is playing.
-    Returns the file path to the blended .wav file.
+    Blend voice and music audio files with optional ducking.
+
+    Args:
+        voice_path (str): File path to the voice audio.
+        music_path (str): File path to the music audio.
+        ducking (bool): If True, attenuate music during voice segments.
+        duck_level (int): Attenuation level in dB.
+
+    Returns:
+        str: File path to the blended .wav file or an error message.
     """
     try:
-        if not os.path.isfile(voice_path) or not os.path.isfile(music_path):
-            return "Error: Missing audio files for blending."
+        if not (os.path.isfile(voice_path) and os.path.isfile(music_path)):
+            raise FileNotFoundError("Missing audio files for blending.")
 
         voice = AudioSegment.from_wav(voice_path)
         music = AudioSegment.from_wav(music_path)
+        voice_duration = len(voice)
 
-        voice_len = len(voice)  # in milliseconds
-        music_len = len(music)  # in milliseconds
-
-        # Loop music if it's shorter than the voice
-        if music_len < voice_len:
+        if len(music) < voice_duration:
             looped_music = AudioSegment.empty()
-            while len(looped_music) < voice_len:
+            while len(looped_music) < voice_duration:
                 looped_music += music
             music = looped_music
-
-        # Trim music if it's longer than the voice
-        if len(music) > voice_len:
-            music = music[:voice_len]
+        else:
+            music = music[:voice_duration]
 
         if ducking:
             ducked_music = music - duck_level
@@ -263,12 +355,13 @@ def blend_audio(voice_path: str, music_path: str, ducking: bool, duck_level: int = 10):
 
         output_path = os.path.join(tempfile.gettempdir(), "blended_output.wav")
         final_audio.export(output_path, format="wav")
+        logging.info(f"Audio blended at {output_path}")
        return output_path
 
     except Exception as e:
+        logging.error(f"Error in blend_audio: {e}")
         return f"Error blending audio: {e}"
 
-
 # ---------------------------------------------------------------------
 # Gradio Interface with Enhanced UI
 # ---------------------------------------------------------------------
@@ -322,7 +415,7 @@ with gr.Blocks(css="""
     Welcome to **AI Promo Studio**! This platform leverages state-of-the-art AI models to help you generate:
 
     - **Script**: Generate a compelling voice-over script with LLaMA.
-    - **Voice Synthesis**: Create natural-sounding voice-overs using Coqui TTS.
+    - **Voice Synthesis**: Create natural-sounding voice-overs using Coqui TTS or Kokoro TTS.
     - **Music Production**: Produce custom music tracks with MusicGen.
    - **Audio Blending**: Seamlessly blend voice and music with options for ducking.
     """)
@@ -355,20 +448,26 @@ with gr.Blocks(css="""
                 music_suggestion_output = gr.Textbox(label="Music Suggestions", lines=3, interactive=False)
 
             generate_script_button.click(
-                fn=lambda user_prompt, model_id, dur: generate_script(user_prompt, model_id, HF_TOKEN, dur),
+                fn=lambda prompt, model, dur: generate_script(prompt, model, HF_TOKEN, dur),
                 inputs=[user_prompt, llama_model_id, duration],
                 outputs=[script_output, sound_design_output, music_suggestion_output],
             )
 
         # Step 2: Generate Voice
         with gr.Tab("🎤 Voice Synthesis"):
-            gr.Markdown("Generate a natural-sounding voice-over using Coqui TTS.")
+            gr.Markdown("Generate a natural-sounding voice-over. Choose your TTS engine below:")
+            voice_engine = gr.Dropdown(
+                label="TTS Engine",
+                choices=["Coqui TTS", "Kokoro TTS"],
+                value="Coqui TTS",
+                multiselect=False
+            )
             selected_tts_model = gr.Dropdown(
-                label="TTS Model",
+                label="TTS Model / Voice Option",
                 choices=[
-                    "tts_models/en/ljspeech/tacotron2-DDC",
-                    "tts_models/en/ljspeech/vits",
-                    "tts_models/en/sam/tacotron-DDC",
+                    "tts_models/en/ljspeech/tacotron2-DDC",  # Coqui TTS option
+                    "tts_models/en/ljspeech/vits",           # Coqui TTS option
+                    "af_heart"                               # Kokoro TTS voice option
                 ],
                 value="tts_models/en/ljspeech/tacotron2-DDC",
                 multiselect=False
@@ -376,9 +475,18 @@ with gr.Blocks(css="""
             generate_voice_button = gr.Button("Generate Voice-Over", variant="primary")
             voice_audio_output = gr.Audio(label="Voice-Over (WAV)", type="filepath")
 
+            def generate_voice_combined(script, engine, model_choice):
+                if engine == "Coqui TTS":
+                    return generate_voice(script, model_choice)
+                elif engine == "Kokoro TTS":
+                    # For Kokoro, pass the voice option (e.g., "af_heart") and default language code ('a')
+                    return generate_voice_kokoro(script, lang_code='a', voice=model_choice, speed=1.0)
+                else:
+                    return "Error: Unknown TTS engine."
+
             generate_voice_button.click(
-                fn=lambda script, tts_model: generate_voice(script, tts_model),
-                inputs=[script_output, selected_tts_model],
+                fn=generate_voice_combined,
+                inputs=[script_output, voice_engine, selected_tts_model],
                 outputs=voice_audio_output,
             )
 
@@ -397,7 +505,7 @@ with gr.Blocks(css="""
             music_output = gr.Audio(label="Generated Music (WAV)", type="filepath")
 
             generate_music_button.click(
-                fn=lambda music_suggestion, length: generate_music(music_suggestion, length),
+                fn=lambda prompt, length: generate_music(prompt, length),
                 inputs=[music_suggestion_output, audio_length],
                 outputs=[music_output],
             )
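The new startup check only logs a warning when HF_TOKEN is absent; the token itself still has to come from the environment. A minimal sketch of the .env file that load_dotenv() reads from the working directory (the value shown is a placeholder, not a real token):

    # .env — placeholder value; keep this file out of version control
    HF_TOKEN=hf_xxxxxxxxxxxxxxxxxxxx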
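All three loader helpers follow the same cache-then-load pattern backed by the module-level dicts, so repeated calls within a process reuse the already-loaded models. A minimal sketch of that behavior (model names as used in the diff; the first call actually downloads weights):

    # Hypothetical smoke test for the caching helpers.
    tts_a = get_tts_model("tts_models/en/ljspeech/tacotron2-DDC")
    tts_b = get_tts_model("tts_models/en/ljspeech/tacotron2-DDC")
    assert tts_a is tts_b  # second call is served from TTS_MODELS

    model, processor = get_musicgen_model("facebook/musicgen-large")
    assert "facebook/musicgen-large" in MUSICGEN_MODELS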
 
 
317
  # ---------------------------------------------------------------------
318
+ # Audio Blending Function
319
  # ---------------------------------------------------------------------
320
  @spaces.GPU(duration=100)
321
+ def blend_audio(voice_path: str, music_path: str, ducking: bool, duck_level: int = 10) -> Union[str, None]:
322
  """
323
+ Blend voice and music audio files with optional ducking.
324
+
325
+ Args:
326
+ voice_path (str): File path to the voice audio.
327
+ music_path (str): File path to the music audio.
328
+ ducking (bool): If True, attenuate music during voice segments.
329
+ duck_level (int): Attenuation level in dB.
330
+
331
+ Returns:
332
+ str: File path to the blended .wav file or an error message.
333
  """
334
  try:
335
+ if not (os.path.isfile(voice_path) and os.path.isfile(music_path)):
336
+ raise FileNotFoundError("Missing audio files for blending.")
337
 
338
  voice = AudioSegment.from_wav(voice_path)
339
  music = AudioSegment.from_wav(music_path)
340
+ voice_duration = len(voice)
341
 
342
+ if len(music) < voice_duration:
 
 
 
 
343
  looped_music = AudioSegment.empty()
344
+ while len(looped_music) < voice_duration:
345
  looped_music += music
346
  music = looped_music
347
+ else:
348
+ music = music[:voice_duration]
 
 
349
 
350
  if ducking:
351
  ducked_music = music - duck_level
 
355
 
356
  output_path = os.path.join(tempfile.gettempdir(), "blended_output.wav")
357
  final_audio.export(output_path, format="wav")
358
+ logging.info(f"Audio blended at {output_path}")
359
  return output_path
360
 
361
  except Exception as e:
362
+ logging.error(f"Error in blend_audio: {e}")
363
  return f"Error blending audio: {e}"
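Outside the Gradio UI, the same pipeline can be driven directly; a minimal end-to-end sketch chaining the functions defined in this file (the model ID is a placeholder, and each step returns an error string instead of a file path on failure, so real callers should check for that):

    script, sound_design, music_idea = generate_script(
        "High-energy morning show intro", "meta-llama/Meta-Llama-3-8B-Instruct", HF_TOKEN, 30
    )
    voice_path = generate_voice(script)            # Coqui TTS engine
    music_path = generate_music(music_idea, 1500)  # ~30 s at ~50 tokens/s
    final_path = blend_audio(voice_path, music_path, ducking=True, duck_level=10)
    print(final_path)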