File size: 12,078 Bytes
17d10a7
a15d204
d448add
db46bfb
1c1b50f
 
db46bfb
1c1b50f
db8ba25
db46bfb
cf3593c
d9bf0f0
b950350
6aba99a
3168a3e
019c404
8e5f278
3168a3e
2de59b3
 
 
cf3593c
2de59b3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cf3593c
1c1b50f
b950350
1c1b50f
a38649c
b950350
2de59b3
 
 
 
dfa5d3e
2de59b3
db8ba25
fd8d42a
dfa5d3e
2de59b3
fd8d42a
2de59b3
 
 
3168a3e
60b6e41
74b6128
e564c8e
2de59b3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fd8d42a
2de59b3
 
 
 
 
 
 
 
 
 
 
fd8d42a
 
2de59b3
b950350
74b6128
b950350
2de59b3
b950350
66b1260
b950350
a38649c
019c404
2de59b3
 
 
b950350
66b1260
dfa5d3e
66b1260
dfa5d3e
2de59b3
dfa5d3e
2de59b3
dfa5d3e
a38649c
217c4b5
2de59b3
 
 
 
17d10a7
2de59b3
 
a3b5047
217c4b5
2de59b3
217c4b5
16184b2
2de59b3
 
 
b950350
217c4b5
2de59b3
217c4b5
d9bf0f0
2de59b3
16184b2
217c4b5
1808e7a
217c4b5
16184b2
cf3593c
b950350
d448add
16184b2
dfa5d3e
2de59b3
dfa5d3e
b950350
2de59b3
 
 
ecc69bf
66b1260
d9bf0f0
66b1260
 
d9bf0f0
 
b950350
ecc69bf
3172dc7
ede9fc5
2de59b3
 
 
ede9fc5
ecc69bf
35e8eba
 
 
 
2de59b3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35e8eba
 
2de59b3
 
 
35e8eba
 
2de59b3
35e8eba
 
 
 
2de59b3
35e8eba
a07ea84
2de59b3
 
a07ea84
35e8eba
 
 
 
2de59b3
 
 
 
 
 
 
 
35e8eba
2de59b3
35e8eba
 
2de59b3
35e8eba
 
 
 
2de59b3
35e8eba
a07ea84
2de59b3
 
a07ea84
8c25665
2de59b3
b950350
2de59b3
 
 
 
b950350
1d543ba
2de59b3
1d543ba
2de59b3
 
 
1d543ba
3fe530b
2de59b3
35e8eba
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
import gradio as gr
import os
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    pipeline,
    AutoProcessor,
    MusicgenForConditionalGeneration,
)
from scipy.io.wavfile import write
from pydub import AudioSegment
from dotenv import load_dotenv
import tempfile
import spaces
from TTS.api import TTS
from TTS.utils.synthesizer import Synthesizer

# ---------------------------------------------------------------------
# Load Environment Variables
# ---------------------------------------------------------------------
# Pull variables from a local .env file (when one exists) into the process
# environment before reading them below.
load_dotenv()
# Hugging Face access token used when downloading gated models.
# os.getenv returns None when HF_TOKEN is not set.
HF_TOKEN = os.getenv("HF_TOKEN")

# ---------------------------------------------------------------------
# Global Model Caches
# ---------------------------------------------------------------------
# We store models/pipelines in global variables for reuse,
# so they are only loaded once.
# LLAMA_PIPELINES maps a model id to its text-generation pipeline;
# MUSICGEN_MODELS maps a model key to a (model, processor) tuple.
LLAMA_PIPELINES = {}
MUSICGEN_MODELS = {}

# ---------------------------------------------------------------------
# Helper Functions
# ---------------------------------------------------------------------
def get_llama_pipeline(model_id: str, token: str):
    """
    Return a text-generation pipeline for ``model_id``, loading it on first use.

    Pipelines are memoized in the module-level ``LLAMA_PIPELINES`` dict keyed
    by ``model_id`` only, so repeated calls skip the expensive model load.

    Args:
        model_id: Hugging Face model identifier to load.
        token: Hugging Face access token forwarded to ``from_pretrained``
            (needed for gated models); may be None for public models.

    Returns:
        The cached or newly created ``transformers`` text-generation pipeline.
    """
    cached = LLAMA_PIPELINES.get(model_id)
    if cached is not None:
        return cached

    # `token=` replaces the deprecated `use_auth_token=` keyword.
    tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
    # SECURITY NOTE(review): trust_remote_code=True executes Python shipped in
    # the model repository. In this app the model id comes from a UI textbox,
    # so consider restricting it to an allow-list or disabling remote code.
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        token=token,
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True,
    )
    text_pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer)
    LLAMA_PIPELINES[model_id] = text_pipeline
    return text_pipeline


def get_musicgen_model(model_key: str = "facebook/musicgen-medium"):
    """
    Fetch the MusicGen ``(model, processor)`` pair for ``model_key``.

    The pair is looked up in the module-level ``MUSICGEN_MODELS`` cache; on a
    miss the model and processor are loaded, the model is moved to CUDA when
    available (CPU otherwise), and the pair is stored for later calls.
    """
    try:
        return MUSICGEN_MODELS[model_key]
    except KeyError:
        pass  # Not cached yet: fall through and load it.

    musicgen = MusicgenForConditionalGeneration.from_pretrained(model_key)
    musicgen_processor = AutoProcessor.from_pretrained(model_key)

    if torch.cuda.is_available():
        target_device = "cuda"
    else:
        target_device = "cpu"
    musicgen.to(target_device)

    entry = (musicgen, musicgen_processor)
    MUSICGEN_MODELS[model_key] = entry
    return entry


# ---------------------------------------------------------------------
# Script Generation Function
# ---------------------------------------------------------------------
def _extract_section(text: str, marker: str, all_markers: tuple, default: str) -> str:
    """Return the text after the first ``marker`` up to the next section marker, or ``default``."""
    _, found, remainder = text.partition(marker)
    if not found:
        return default
    # End the section at the earliest following header (including a repeat of
    # its own), so a missing or re-ordered section cannot leak into this one.
    cut_points = [pos for pos in (remainder.find(m) for m in all_markers) if pos != -1]
    if cut_points:
        remainder = remainder[:min(cut_points)]
    return remainder.strip()


@spaces.GPU(duration=100)
def generate_script(user_prompt: str, model_id: str, token: str, duration: int):
    """
    Generate a voice-over script, sound design suggestions, and music ideas.

    Args:
        user_prompt: The user's promo concept.
        model_id: Hugging Face model id of the text-generation model to use.
        token: Hugging Face access token passed to the model loader.
        duration: Target promo duration in seconds, embedded in the prompt.

    Returns:
        A tuple of strings ``(voice_script, sound_design, music_suggestions)``.
        A section the model did not produce is replaced by a "No ... found."
        placeholder. On any failure the first element carries the error
        message and the other two are empty strings.
    """
    try:
        text_pipeline = get_llama_pipeline(model_id, token)

        # System prompt with clear structure instructions
        system_prompt = (
            "You are an expert radio imaging producer specializing in sound design and music. "
            f"Based on the user's concept and the selected duration of {duration} seconds, produce the following: "
            "1. A concise voice-over script. Prefix this section with 'Voice-Over Script:'.\n"
            "2. Suggestions for sound design. Prefix this section with 'Sound Design Suggestions:'.\n"
            "3. Music styles or track recommendations. Prefix this section with 'Music Suggestions:'."
        )

        combined_prompt = f"{system_prompt}\nUser concept: {user_prompt}\nOutput:"

        # Use inference mode for efficient forward passes
        with torch.inference_mode():
            result = text_pipeline(
                combined_prompt,
                max_new_tokens=300,
                do_sample=True,
                temperature=0.8,
            )

        # The pipeline returns a list of dicts with "generated_text"
        generated_text = result[0]["generated_text"]

        # Drop the echoed prompt. Stripping the exact prefix is preferred over
        # splitting on "Output:", which would also discard real content if
        # the model itself emits "Output:" somewhere in its answer.
        if generated_text.startswith(combined_prompt):
            generated_text = generated_text[len(combined_prompt):].strip()
        elif "Output:" in generated_text:
            generated_text = generated_text.split("Output:")[-1].strip()

        markers = (
            "Voice-Over Script:",
            "Sound Design Suggestions:",
            "Music Suggestions:",
        )
        voice_script = _extract_section(
            generated_text, markers[0], markers, "No voice-over script found."
        )
        sound_design = _extract_section(
            generated_text, markers[1], markers, "No sound design suggestions found."
        )
        music_suggestions = _extract_section(
            generated_text, markers[2], markers, "No music suggestions found."
        )

        return voice_script, sound_design, music_suggestions

    except Exception as e:
        # Boundary of a UI callback: surface the failure in the first textbox
        # instead of crashing the request.
        return f"Error generating script: {e}", "", ""


# ---------------------------------------------------------------------
# Voice-Over Generation Function (Inactive)
# ---------------------------------------------------------------------
@spaces.GPU(duration=100)
def generate_voice(script: str, speaker: str = "default"):
    """
    Placeholder for future voice-over generation functionality.

    Args:
        script: Voice-over text to synthesize (currently unused).
        speaker: Speaker/voice identifier (currently unused).

    Returns:
        A fixed status message stating that the feature is inactive.
    """
    # Returning a constant cannot raise, so no exception handling is needed.
    return "Voice-over generation is currently inactive."


# ---------------------------------------------------------------------
# Music Generation Function
# ---------------------------------------------------------------------
@spaces.GPU(duration=100)
def generate_music(prompt: str, audio_length: int):
    """
    Generate music with the 'facebook/musicgen-medium' model from a text prompt.

    Args:
        prompt: Text description of the music to generate.
        audio_length: Number of new tokens to generate; more tokens produce
            longer audio.

    Returns:
        The file path of the generated .wav file, or an
        "Error generating music: ..." string if anything fails.
    """
    try:
        model_key = "facebook/musicgen-medium"
        musicgen_model, musicgen_processor = get_musicgen_model(model_key)

        device = "cuda" if torch.cuda.is_available() else "cpu"
        # Prepare input
        inputs = musicgen_processor(text=[prompt], padding=True, return_tensors="pt").to(device)

        # Generate music within inference mode. UI sliders may deliver a
        # float, so coerce the token budget to int.
        with torch.inference_mode():
            outputs = musicgen_model.generate(**inputs, max_new_tokens=int(audio_length))

        # First batch item, first channel, as a NumPy array.
        audio_data = outputs[0, 0].cpu().numpy()

        # Peak-normalize to the int16 range. Guard against empty or silent
        # output, which would otherwise divide by zero and yield NaNs.
        peak = float(abs(audio_data).max()) if audio_data.size else 0.0
        if peak > 0:
            audio_data = audio_data / peak
        normalized_audio = (audio_data * 32767).astype("int16")

        # Write at the model's own sampling rate; a hard-coded 44100 Hz would
        # play MusicGen output at the wrong speed and pitch.
        sampling_rate = musicgen_model.config.audio_encoder.sampling_rate

        # Save generated music to a temp file
        output_path = f"{tempfile.gettempdir()}/musicgen_medium_generated_music.wav"
        write(output_path, sampling_rate, normalized_audio)

        return output_path

    except Exception as e:
        return f"Error generating music: {e}"


# ---------------------------------------------------------------------
# Audio Blending Function (Inactive)
# ---------------------------------------------------------------------
def blend_audio(voice_path: str, music_path: str, ducking: bool):
    """
    Placeholder for future audio blending functionality with optional ducking.

    Args:
        voice_path: Path to the voice-over audio file (currently unused).
        music_path: Path to the music audio file (currently unused).
        ducking: Whether to lower the music under the voice (currently unused).

    Returns:
        A fixed status message stating that the feature is inactive.
    """
    # Returning a constant cannot raise, so no exception handling is needed.
    return "Audio blending functionality is currently inactive."


# ---------------------------------------------------------------------
# Gradio Interface
# ---------------------------------------------------------------------
# Build the UI: a header, four tabs (script, voice, music, blend), and a footer.
# Only the script and music tabs are wired to callbacks; the voice and blend
# tabs just show an "inactive" notice.
with gr.Blocks() as demo:
    gr.Markdown("""
    # 🎧 AI Promo Studio 🚀  
    Welcome to **AI Promo Studio**, your one-stop solution for creating stunning and professional radio promos with ease!  
    Whether you're a sound designer, radio producer, or content creator, our AI-driven tools, powered by advanced LLM Llama models, empower you to bring your vision to life in just a few steps.  
    """)

    with gr.Tabs():
        # Step 1: Generate Script
        with gr.Tab("Step 1: Generate Script"):
            with gr.Row():
                user_prompt = gr.Textbox(
                    label="Promo Idea", 
                    placeholder="E.g., A 30-second promo for a morning show...",
                    lines=2
                )
                llama_model_id = gr.Textbox(
                    label="LLaMA Model ID", 
                    value="meta-llama/Meta-Llama-3-8B-Instruct", 
                    placeholder="Enter a valid Hugging Face model ID"
                )
                duration = gr.Slider(
                    label="Desired Promo Duration (seconds)",
                    minimum=15, 
                    maximum=60, 
                    step=15, 
                    value=30
                )

            generate_script_button = gr.Button("Generate Script")
            # Read-only outputs, one per element of generate_script's return tuple.
            script_output = gr.Textbox(label="Generated Voice-Over Script", lines=5, interactive=False)
            sound_design_output = gr.Textbox(label="Sound Design Suggestions", lines=3, interactive=False)
            music_suggestion_output = gr.Textbox(label="Music Suggestions", lines=3, interactive=False)

            # The lambda injects the module-level HF_TOKEN so the token is
            # never exposed as a UI input.
            generate_script_button.click(
                fn=lambda user_prompt, model_id, dur: generate_script(user_prompt, model_id, HF_TOKEN, dur),
                inputs=[user_prompt, llama_model_id, duration],
                outputs=[script_output, sound_design_output, music_suggestion_output],
            )

        # Step 2: Generate Voice (Inactive)
        with gr.Tab("Step 2: Generate Voice"):
            gr.Markdown("""
            **Note:** Voice-over generation is currently inactive. 
            This feature will be available in future updates!
            """)

        # Step 3: Generate Music
        with gr.Tab("Step 3: Generate Music"):
            with gr.Row():
                audio_length = gr.Slider(
                    label="Music Length (tokens)",
                    minimum=128, 
                    maximum=1024, 
                    step=64, 
                    value=512,
                    info="Increase tokens for longer audio, but be mindful of inference time."
                )
            generate_music_button = gr.Button("Generate Music")
            music_output = gr.Audio(label="Generated Music (WAV)", type="filepath")

            # The music prompt is the "Music Suggestions" text produced in Step 1.
            generate_music_button.click(
                fn=lambda music_suggestion, length: generate_music(music_suggestion, length),
                inputs=[music_suggestion_output, audio_length],
                outputs=[music_output],
            )

        # Step 4: Blend Audio (Inactive)
        with gr.Tab("Step 4: Blend Audio"):
            gr.Markdown("""
            **Note:** Audio blending functionality is currently inactive. 
            This feature will be available in future updates!
            """)

    # Footer / Credits
    gr.Markdown("""
    <hr>
    <p style="text-align: center; font-size: 0.9em;">
        Created with ❤️ by <a href="https://bilsimaging.com" target="_blank">bilsimaging.com</a>
    </p>
    """)
    
    # Visitor Badge
    gr.HTML("""
    <a href="https://visitorbadge.io/status?path=https%3A%2F%2Fhuggingface.co%2Fspaces%2FBils%2Fradiogold">
        <img src="https://api.visitorbadge.io/api/visitors?path=https%3A%2F%2Fhuggingface.co%2Fspaces%2FBils%2Fradiogold&countColor=%23263759" />
    </a>
    """)

# Launch the Gradio app only when this file is executed as a script, so that
# importing the module (e.g. from tests or tooling) does not start a server.
if __name__ == "__main__":
    demo.launch(debug=True)