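"""AI Radio Imaging demo: Llama 3 promo-script writing + MusicGen audio.

Intended for Hugging Face Spaces with ZeroGPU: both models are lazy-loaded
inside @spaces.GPU-decorated functions, and an HF_TOKEN environment variable
(e.g. from a .env file) is expected for gated models such as Llama 3.
"""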
import gradio as gr
import os
import torch
import time
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    pipeline,
    AutoProcessor,
    MusicgenForConditionalGeneration,
)
import numpy as np
from scipy.io.wavfile import write
import tempfile
from dotenv import load_dotenv
import spaces  # Hugging Face Spaces library for ZeroGPU support

# Load environment variables (e.g., Hugging Face token)
load_dotenv()
hf_token = os.getenv("HF_TOKEN")

# Globals for lazy loading
llama_pipeline = None
musicgen_model = None
musicgen_processor = None

# ---------------------------------------------------------------------
# Helper: Safe Model Loader with Retry Logic
# ---------------------------------------------------------------------
def safe_load_model(model_id, token, retries=3, delay=5):
    for attempt in range(retries):
        try:
            model = AutoModelForCausalLM.from_pretrained(
                model_id,
                token=token,  # `use_auth_token` is deprecated in recent transformers
                torch_dtype=torch.float16,
                device_map="auto",
                trust_remote_code=True,
                offload_folder="/tmp",  # Where accelerate offloads weights that don't fit on the device
                cache_dir="/tmp"        # Cache directory for shard downloads
            )
            return model
        except Exception as e:
            print(f"Attempt {attempt + 1} failed: {e}")
            if attempt < retries - 1:  # don't sleep after the final attempt
                time.sleep(delay)
    raise RuntimeError(f"Failed to load model {model_id} after {retries} attempts")
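# Usage sketch (illustrative; assumes HF_TOKEN grants access to the model):
#   model = safe_load_model("meta-llama/Meta-Llama-3-8B", hf_token)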

# ---------------------------------------------------------------------
# Load Llama 3 Model with Zero GPU (Lazy Loading)
# ---------------------------------------------------------------------
@spaces.GPU(duration=600)  # Increased duration to handle large models
def load_llama_pipeline_zero_gpu(model_id: str, token: str):
    global llama_pipeline
    if llama_pipeline is None:
        try:
            print("Starting model loading...")
            tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
            print("Tokenizer loaded.")
            model = safe_load_model(model_id, token)
            print("Model loaded. Initializing pipeline...")
            llama_pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer)
            print("Pipeline initialized successfully.")
        except Exception as e:
            print(f"Error loading Llama pipeline: {e}")
            return str(e)
    return llama_pipeline
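# Note: the pipeline is cached in the module-level global above, so repeated
# ZeroGPU calls reuse it for as long as the Space process stays warm.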

# ---------------------------------------------------------------------
# Generate Radio Script
# ---------------------------------------------------------------------
def generate_script(user_input: str, pipeline_llama):
    try:
        system_prompt = (
            "You are a top-tier radio imaging producer using Llama 3. "
            "Take the user's concept and craft a short, creative promo script."
        )
        combined_prompt = f"{system_prompt}\nUser concept: {user_input}\nRefined script:"
        result = pipeline_llama(combined_prompt, max_new_tokens=200, do_sample=True, temperature=0.9)
        return result[0]['generated_text'].split("Refined script:")[-1].strip()
    except Exception as e:
        return f"Error generating script: {e}"

# ---------------------------------------------------------------------
# Load MusicGen Model (Lazy Loading)
# ---------------------------------------------------------------------
@spaces.GPU(duration=600)
def load_musicgen_model():
    global musicgen_model, musicgen_processor
    if musicgen_model is None or musicgen_processor is None:
        try:
            print("Loading MusicGen model...")
            musicgen_model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
            musicgen_processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
            print("MusicGen model loaded successfully.")
        except Exception as e:
            print(f"Error loading MusicGen model: {e}")
            return None, str(e)
    return musicgen_model, musicgen_processor
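# Design note: MusicGen is loaded on CPU here and only moved to the GPU inside
# generate_audio(), so the ZeroGPU allocation is held just for generation.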

# ---------------------------------------------------------------------
# Generate Audio
# ---------------------------------------------------------------------
@spaces.GPU(duration=600)
def generate_audio(prompt: str, audio_length: int):
    global musicgen_model, musicgen_processor
    if musicgen_model is None or musicgen_processor is None:
        musicgen_model, musicgen_processor = load_musicgen_model()
        if musicgen_model is None:
            return musicgen_processor  # the loader returns (None, error_message) on failure
    try:
        musicgen_model.to("cuda")  # Move the model to GPU
        # Inputs must live on the same device as the model before generate().
        inputs = musicgen_processor(text=[prompt], padding=True, return_tensors="pt").to("cuda")
        outputs = musicgen_model.generate(**inputs, max_new_tokens=audio_length)
        musicgen_model.to("cpu")  # Return the model to CPU

        sr = musicgen_model.config.audio_encoder.sampling_rate
        audio_data = outputs[0, 0].cpu().numpy()
        # Scale the float waveform to the int16 range expected by WAV files;
        # the small epsilon guards against division by zero on silent output.
        peak = max(np.abs(audio_data).max(), 1e-8)
        normalized_audio = (audio_data / peak * 32767).astype("int16")

        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_wav:
            write(temp_wav.name, sr, normalized_audio)
            return temp_wav.name
    except Exception as e:
        return f"Error generating audio: {e}"

# ---------------------------------------------------------------------
# Gradio Interface
# ---------------------------------------------------------------------
def generate_script_interface(user_prompt, llama_model_id):
    # Load Llama 3 Pipeline with Zero GPU
    pipeline_llama = load_llama_pipeline_zero_gpu(llama_model_id, hf_token)
    if isinstance(pipeline_llama, str):
        return pipeline_llama

    # Generate Script
    script = generate_script(user_prompt, pipeline_llama)
    return script

def generate_audio_interface(script, audio_length):
    # Generate Audio
    audio_data = generate_audio(script, audio_length)
    return audio_data

# ---------------------------------------------------------------------
# Interface
# ---------------------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown("# 🎧 AI Radio Imaging with Llama 3 + MusicGen (Zero GPU)")

    with gr.Row():
        user_prompt = gr.Textbox(label="Enter your promo idea", placeholder="E.g., A 15-second hype jingle for a morning talk show.")
        llama_model_id = gr.Textbox(label="Llama 3 Model ID", value="meta-llama/Meta-Llama-3-8B")  # the smallest Llama 3 variant, for broader compatibility
        audio_length = gr.Slider(label="Audio Length (tokens)", minimum=128, maximum=1024, step=64, value=512)

    with gr.Row():
        generate_script_button = gr.Button("Generate Promo Script")
        script_output = gr.Textbox(label="Generated Script", interactive=False)

    with gr.Row():
        generate_audio_button = gr.Button("Generate Audio")
        audio_output = gr.Audio(label="Generated Audio", type="filepath")

    generate_script_button.click(
        generate_script_interface, 
        inputs=[user_prompt, llama_model_id], 
        outputs=script_output
    )

    generate_audio_button.click(
        generate_audio_interface, 
        inputs=[script_output, audio_length], 
        outputs=audio_output
    )

# ---------------------------------------------------------------------
# Launch App
# ---------------------------------------------------------------------
demo.launch(debug=True)