File size: 5,209 Bytes
17d10a7
a15d204
d448add
db46bfb
 
 
 
 
 
 
17d10a7
c243adb
 
17d10a7
d0384c8
17d10a7
613bd9e
 
 
 
 
17d10a7
 
613bd9e
 
17d10a7
613bd9e
17d10a7
d0384c8
 
17d10a7
d0384c8
17d10a7
 
 
 
 
 
 
 
 
 
 
d0384c8
 
17d10a7
d0384c8
 
17d10a7
 
 
 
 
 
d0384c8
 
17d10a7
3fe530b
17d10a7
 
 
 
 
 
 
 
 
 
 
 
d448add
db46bfb
17d10a7
db46bfb
17d10a7
 
 
 
 
d448add
17d10a7
 
d448add
17d10a7
 
 
 
a15d204
17d10a7
 
 
 
c243adb
17d10a7
d448add
db46bfb
17d10a7
c243adb
17d10a7
 
 
 
 
 
 
d448add
17d10a7
 
 
d448add
17d10a7
 
 
3fe530b
 
17d10a7
db46bfb
17d10a7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import gradio as gr
import os
import torch
from transformers import (
    AutoTokenizer, 
    AutoModelForCausalLM, 
    pipeline,
    AutoProcessor, 
    MusicgenForConditionalGeneration
)
import scipy.io.wavfile as wav

# ---------------------------------------------------------------------
# Load Llama 3 Model
# ---------------------------------------------------------------------
def load_llama_pipeline(model_id: str, token: str, device: str = "cpu"):
    """Build a text-generation pipeline for a (possibly gated) Llama 3 checkpoint.

    Args:
        model_id: Hugging Face Hub repo id, e.g. "meta-llama/Meta-Llama-3-70B".
        token: Hugging Face access token. An empty string is treated as
            "no token" (None) instead of being sent as an empty credential.
        device: "cuda" loads fp16 weights dispatched with ``device_map="auto"``;
            any other value loads fp32 weights on the CPU.

    Returns:
        The ``transformers`` pipeline on success, or the exception message as a
        ``str`` on failure (callers distinguish via ``isinstance(result, str)``).
    """
    use_cuda = device == "cuda"
    # Gradio hands over "" for an empty textbox; None lets the Hub client use
    # whatever credentials are already configured.
    auth_token = token or None
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_id, token=auth_token)
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            token=auth_token,
            torch_dtype=torch.float16 if use_cuda else torch.float32,
            device_map="auto" if use_cuda else None,
            low_cpu_mem_usage=True,
        )
        # A model dispatched with device_map is already placed by accelerate;
        # passing `device=` as well makes the pipeline raise ValueError, so the
        # device is only pinned explicitly on the CPU path.
        device_kwargs = {} if use_cuda else {"device": -1}
        return pipeline(
            "text-generation", model=model, tokenizer=tokenizer, **device_kwargs
        )
    except Exception as e:
        return str(e)

# ---------------------------------------------------------------------
# Generate Radio Script
# ---------------------------------------------------------------------
def generate_script(
    user_input: str,
    pipeline_llama,
    max_new_tokens: int = 200,
    temperature: float = 0.9,
):
    """Ask the Llama pipeline to turn a promo concept into a short radio script.

    Args:
        user_input: The user's free-text promo idea.
        pipeline_llama: A callable text-generation pipeline taking a prompt plus
            generation kwargs and returning ``[{"generated_text": ...}]``.
        max_new_tokens: Upper bound on generated tokens.
        temperature: Sampling temperature (sampling is always enabled).

    Returns:
        The generated script text (whitespace-stripped), or a string starting
        with ``"Error generating script:"`` if generation fails.
    """
    try:
        system_prompt = (
            "You are a top-tier radio imaging producer using Llama 3. "
            "Take the user's concept and craft a short, creative promo script."
        )
        combined_prompt = f"{system_prompt}\nUser concept: {user_input}\nRefined script:"
        result = pipeline_llama(
            combined_prompt,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=temperature,
            # Return only the continuation, so the prompt never has to be cut
            # back out by string-splitting (which broke whenever the model
            # repeated the "Refined script:" marker).
            return_full_text=False,
        )
        return result[0]["generated_text"].strip()
    except Exception as e:
        return f"Error generating script: {e}"

# ---------------------------------------------------------------------
# Load MusicGen Model
# ---------------------------------------------------------------------
# Successfully loaded (model, processor) pairs keyed by Hub model id, so each
# Gradio request does not reload the checkpoint from disk.
_MUSICGEN_CACHE = {}


def load_musicgen_model(model_id: str = "facebook/musicgen-small"):
    """Load (or reuse) a MusicGen model and its processor.

    Args:
        model_id: Hugging Face Hub repo id of the MusicGen checkpoint.

    Returns:
        ``(model, processor)`` on success, or ``(None, error_message)`` on
        failure. Failures are not cached, so a later call retries the load.
    """
    cached = _MUSICGEN_CACHE.get(model_id)
    if cached is not None:
        return cached
    try:
        model = MusicgenForConditionalGeneration.from_pretrained(model_id)
        processor = AutoProcessor.from_pretrained(model_id)
    except Exception as e:
        return None, str(e)
    _MUSICGEN_CACHE[model_id] = (model, processor)
    return model, processor

# ---------------------------------------------------------------------
# Generate Audio
# ---------------------------------------------------------------------
def generate_audio(
    prompt: str,
    audio_length: int,
    mg_model,
    mg_processor,
    output_file: str = "radio_jingle.wav",
):
    """Render *prompt* with MusicGen and save it as a 16-bit WAV file.

    Args:
        prompt: Text description passed to the MusicGen processor.
        audio_length: Number of audio tokens to generate (``max_new_tokens``).
            It is coerced to ``int`` because a Gradio slider may deliver a float.
        mg_model: A loaded ``MusicgenForConditionalGeneration``.
        mg_processor: The matching ``AutoProcessor``.
        output_file: Destination path (default ``radio_jingle.wav`` in the
            current working directory).

    Returns:
        ``output_file`` on success, or a string starting with
        ``"Error generating audio:"`` on failure. The prefix matches the
        caller's ``startswith("Error")`` check.
    """
    try:
        inputs = mg_processor(text=[prompt], padding=True, return_tensors="pt")
        outputs = mg_model.generate(**inputs, max_new_tokens=int(audio_length))
        sr = mg_model.config.audio_encoder.sampling_rate
        # outputs[0, 0]: first batch item, first channel. This assumes
        # MusicGen returns (batch, channels, samples); confirm for stereo models.
        audio_data = outputs[0, 0].cpu().numpy()
        # Peak-normalise into the int16 range. Silent or empty output is left
        # unscaled, because dividing by a zero peak would produce NaNs.
        peak = float(abs(audio_data).max()) if audio_data.size else 0.0
        if peak > 0:
            audio_data = audio_data / peak * 32767
        normalized_audio = audio_data.astype("int16")
        wav.write(output_file, rate=sr, data=normalized_audio)
        return output_file
    except Exception as e:
        return f"Error generating audio: {e}"

# ---------------------------------------------------------------------
# Gradio Interface
# ---------------------------------------------------------------------
def radio_imaging_app(user_prompt, llama_model_id, hf_token, audio_length):
    """Gradio callback: turn a promo idea into a script plus a MusicGen jingle.

    Args:
        user_prompt: The promo concept typed by the user.
        llama_model_id: Hub id of the Llama checkpoint to load.
        hf_token: Hugging Face access token (may be empty).
        audio_length: MusicGen token budget from the slider.

    Returns:
        ``(text, audio_path)`` for the script textbox and the audio component.
        On any failure the text slot carries the error message and
        ``audio_path`` is None, so ``gr.Audio`` is never given a non-file string.
    """
    # Avoid loading large models when there is nothing to generate from.
    if not user_prompt or not user_prompt.strip():
        return "Please enter a promo idea.", None

    # Load Llama 3 Pipeline
    device = "cuda" if torch.cuda.is_available() else "cpu"
    pipeline_llama = load_llama_pipeline(llama_model_id, hf_token, device=device)
    if isinstance(pipeline_llama, str):
        return pipeline_llama, None

    # Generate Script. Failures come back in-band as text, and that text must
    # not be forwarded to MusicGen as if it were a script.
    script = generate_script(user_prompt, pipeline_llama)
    if script.startswith("Error generating script"):
        return script, None

    # Load MusicGen
    mg_model, mg_processor = load_musicgen_model()
    if isinstance(mg_processor, str):
        return f"{script}\n\n[MusicGen load failed] {mg_processor}", None

    # Generate Audio. On success generate_audio returns the WAV path and
    # otherwise an error message, so only an existing file counts as success.
    audio_file = generate_audio(script, int(audio_length), mg_model, mg_processor)
    if not (isinstance(audio_file, str) and os.path.isfile(audio_file)):
        return f"{script}\n\n[Audio generation failed] {audio_file}", None

    return script, audio_file

# ---------------------------------------------------------------------
# Interface
# ---------------------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown("# 🎧 AI Radio Imaging with Llama 3 + MusicGen")
    with gr.Row():
        user_prompt = gr.Textbox(label="Enter your promo idea", placeholder="E.g., A 15-second hype jingle for a morning talk show, fun and energetic.")
        llama_model_id = gr.Textbox(label="Llama 3 Model ID", value="meta-llama/Meta-Llama-3-70B")
        hf_token = gr.Textbox(label="Hugging Face Token", type="password")
        # Forwarded to MusicGen as max_new_tokens, so the unit is tokens, not seconds.
        audio_length = gr.Slider(label="Audio Length (tokens)", minimum=128, maximum=1024, step=64, value=512)

    generate_button = gr.Button("Generate Promo Script and Audio")
    script_output = gr.Textbox(label="Generated Script")
    # The callback returns a WAV file path (or None). "filepath" is the gr.Audio
    # type that accepts a path; "file" is not an accepted value.
    audio_output = gr.Audio(label="Generated Audio", type="filepath")

    generate_button.click(
        radio_imaging_app,
        inputs=[user_prompt, llama_model_id, hf_token, audio_length],
        outputs=[script_output, audio_output],
    )

# ---------------------------------------------------------------------
# Launch App
# ---------------------------------------------------------------------
demo.launch()