import gradio as gr
import os
import numpy as np
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    pipeline,
    AutoProcessor,
    MusicgenForConditionalGeneration,
)
import scipy.io.wavfile as wav
# ---------------------------------------------------------------------
# Load Llama 3 Model
# ---------------------------------------------------------------------
def load_llama_pipeline(model_id: str, token: str, device: str = "cpu"):
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            token=token,
            torch_dtype=torch.float16 if device == "cuda" else torch.float32,
            device_map="auto" if device == "cuda" else None,
            low_cpu_mem_usage=True,
        )
        # With device_map="auto", accelerate has already placed the model, so
        # the pipeline must not be given an explicit device; on CPU, -1 is correct.
        if device == "cuda":
            return pipeline("text-generation", model=model, tokenizer=tokenizer)
        return pipeline("text-generation", model=model, tokenizer=tokenizer, device=-1)
    except Exception as e:
        return str(e)
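# A quick smoke test of the loader, sketched as a comment so the module stays
# importable. The 8B instruct checkpoint is a real repo; the HF_TOKEN
# environment variable is an assumption about how you store your token:
#   pipe = load_llama_pipeline("meta-llama/Meta-Llama-3-8B-Instruct",
#                              token=os.environ.get("HF_TOKEN"))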
# ---------------------------------------------------------------------
# Generate Radio Script
# ---------------------------------------------------------------------
def generate_script(user_input: str, pipeline_llama):
    try:
        system_prompt = (
            "You are a top-tier radio imaging producer using Llama 3. "
            "Take the user's concept and craft a short, creative promo script."
        )
        combined_prompt = f"{system_prompt}\nUser concept: {user_input}\nRefined script:"
        result = pipeline_llama(combined_prompt, max_new_tokens=200, do_sample=True, temperature=0.9)
        return result[0]["generated_text"].split("Refined script:")[-1].strip()
    except Exception as e:
        return f"Error generating script: {e}"
# ---------------------------------------------------------------------
# Load MusicGen Model
# ---------------------------------------------------------------------
def load_musicgen_model():
    try:
        model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
        processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
        return model, processor
    except Exception as e:
        return None, str(e)
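# Note: "facebook/musicgen-small" keeps the memory footprint modest; the larger
# "facebook/musicgen-medium" / "facebook/musicgen-large" checkpoints should be
# drop-in replacements here if the hardware has the headroom.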
# ---------------------------------------------------------------------
# Generate Audio
# ---------------------------------------------------------------------
def generate_audio(prompt: str, audio_length: int, mg_model, mg_processor):
    try:
        inputs = mg_processor(text=[prompt], padding=True, return_tensors="pt")
        outputs = mg_model.generate(**inputs, max_new_tokens=audio_length)
        sr = mg_model.config.audio_encoder.sampling_rate
        audio_data = outputs[0, 0].cpu().numpy()
        # Peak-normalize into the int16 range expected by scipy.io.wavfile,
        # guarding against all-silent output.
        peak = np.max(np.abs(audio_data))
        if peak > 0:
            audio_data = audio_data / peak
        normalized_audio = (audio_data * 32767).astype("int16")
        output_file = "radio_jingle.wav"
        wav.write(output_file, rate=sr, data=normalized_audio)
        return output_file
    except Exception as e:
        # Prefix with "Error" so the caller's startswith("Error") check works.
        return f"Error generating audio: {e}"
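# Rough duration math: MusicGen emits ~50 audio tokens per second, so the
# 128-1024 token range exposed by the slider below maps to roughly 2.5-20
# seconds of audio.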
# ---------------------------------------------------------------------
# Gradio Interface
# ---------------------------------------------------------------------
def radio_imaging_app(user_prompt, llama_model_id, hf_token, audio_length):
    # Load Llama 3 pipeline
    pipeline_llama = load_llama_pipeline(llama_model_id, hf_token, device="cuda" if torch.cuda.is_available() else "cpu")
    if isinstance(pipeline_llama, str):
        return pipeline_llama, None

    # Generate script
    script = generate_script(user_prompt, pipeline_llama)

    # Load MusicGen
    mg_model, mg_processor = load_musicgen_model()
    if isinstance(mg_processor, str):
        # Surface the load error in the script box; there is no audio to return.
        return f"{script}\n\n{mg_processor}", None

    # Generate audio
    audio_file = generate_audio(script, audio_length, mg_model, mg_processor)
    if isinstance(audio_file, str) and audio_file.startswith("Error"):
        # An error string is not a playable file, so return None for the audio slot.
        return f"{script}\n\n{audio_file}", None
    return script, audio_file
# ---------------------------------------------------------------------
# Interface
# ---------------------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown("# 🎧 AI Radio Imaging with Llama 3 + MusicGen")
    with gr.Row():
        user_prompt = gr.Textbox(label="Enter your promo idea", placeholder="E.g., A 15-second hype jingle for a morning talk show, fun and energetic.")
        llama_model_id = gr.Textbox(label="Llama 3 Model ID", value="meta-llama/Meta-Llama-3-70B")
        hf_token = gr.Textbox(label="Hugging Face Token", type="password")
        audio_length = gr.Slider(label="Audio Length (tokens)", minimum=128, maximum=1024, step=64, value=512)
    generate_button = gr.Button("Generate Promo Script and Audio")
    script_output = gr.Textbox(label="Generated Script")
    audio_output = gr.Audio(label="Generated Audio", type="filepath")
    generate_button.click(
        radio_imaging_app,
        inputs=[user_prompt, llama_model_id, hf_token, audio_length],
        outputs=[script_output, audio_output],
    )
# ---------------------------------------------------------------------
# Launch App
# ---------------------------------------------------------------------
demo.launch()
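# Note: when running locally rather than on Spaces, demo.launch(share=True)
# also exposes a temporary public URL.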