Spaces · Running on Zero · File size: 7,388 Bytes
import gradio as gr
import os
import torch
import numpy as np
import time
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
pipeline,
AutoProcessor,
MusicgenForConditionalGeneration,
)
from scipy.io.wavfile import write
import tempfile
from dotenv import load_dotenv
import spaces # Hugging Face Spaces library for ZeroGPU support
# Load environment variables (e.g., Hugging Face token)
load_dotenv()
hf_token = os.getenv("HF_TOKEN")
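# NOTE: The meta-llama checkpoints are gated on the Hugging Face Hub, so HF_TOKEN must be
# set (e.g., as a Space secret) for an account that has accepted the model license.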
# Globals for lazy loading
llama_pipeline = None
musicgen_model = None
musicgen_processor = None
# ---------------------------------------------------------------------
# Helper: Safe Model Loader with Retry Logic
# ---------------------------------------------------------------------
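# Large checkpoints occasionally fail mid-download on Spaces, so the loader retries a few
# times with a short delay; shards are offloaded to and cached under /tmp to limit memory use.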
def safe_load_model(model_id, token, retries=3, delay=5):
for attempt in range(retries):
try:
model = AutoModelForCausalLM.from_pretrained(
model_id,
use_auth_token=token,
torch_dtype=torch.float16,
device_map="auto",
trust_remote_code=True,
                offload_folder="/tmp",  # Offload weights that do not fit in memory to disk
cache_dir="/tmp" # Cache directory for shard downloads
)
return model
except Exception as e:
print(f"Attempt {attempt + 1} failed: {e}")
time.sleep(delay)
raise RuntimeError(f"Failed to load model {model_id} after {retries} attempts")
# ---------------------------------------------------------------------
# Load Llama 3 Model with Zero GPU (Lazy Loading)
# ---------------------------------------------------------------------
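# @spaces.GPU requests a ZeroGPU device only for the duration of the decorated call; the
# pipeline is cached in a module-level global so later calls can skip reloading the weights.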
@spaces.GPU(duration=600) # Increased duration to handle large models
def load_llama_pipeline_zero_gpu(model_id: str, token: str):
global llama_pipeline
if llama_pipeline is None:
try:
print("Starting model loading...")
tokenizer = AutoTokenizer.from_pretrained(model_id, use_auth_token=token)
print("Tokenizer loaded.")
model = safe_load_model(model_id, token)
print("Model loaded. Initializing pipeline...")
llama_pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer)
print("Pipeline initialized successfully.")
except Exception as e:
print(f"Error loading Llama pipeline: {e}")
return str(e)
return llama_pipeline
# ---------------------------------------------------------------------
# Generate Radio Script
# ---------------------------------------------------------------------
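# The script generator prepends a fixed system prompt, samples a completion, and keeps only
# the text that follows the "Refined script:" marker.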
def generate_script(user_input: str, pipeline_llama):
try:
system_prompt = (
"You are a top-tier radio imaging producer using Llama 3. "
"Take the user's concept and craft a short, creative promo script."
)
combined_prompt = f"{system_prompt}\nUser concept: {user_input}\nRefined script:"
result = pipeline_llama(combined_prompt, max_new_tokens=200, do_sample=True, temperature=0.9)
return result[0]['generated_text'].split("Refined script:")[-1].strip()
except Exception as e:
return f"Error generating script: {e}"
# ---------------------------------------------------------------------
# Load MusicGen Model (Lazy Loading)
# ---------------------------------------------------------------------
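# MusicGen is loaded lazily on first use; "facebook/musicgen-small" (~300M parameters) keeps
# memory low, and the model is only moved to the GPU inside generate_audio().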
@spaces.GPU(duration=600)
def load_musicgen_model():
global musicgen_model, musicgen_processor
if musicgen_model is None or musicgen_processor is None:
try:
print("Loading MusicGen model...")
musicgen_model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
musicgen_processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
print("MusicGen model loaded successfully.")
except Exception as e:
print(f"Error loading MusicGen model: {e}")
return None, str(e)
return musicgen_model, musicgen_processor
# ---------------------------------------------------------------------
# Generate Audio
# ---------------------------------------------------------------------
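# audio_length is measured in MusicGen tokens; the model emits roughly 50 tokens per second
# of audio, so the default of 512 tokens corresponds to about 10 seconds.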
@spaces.GPU(duration=600)
def generate_audio(prompt: str, audio_length: int):
global musicgen_model, musicgen_processor
if musicgen_model is None or musicgen_processor is None:
musicgen_model, musicgen_processor = load_musicgen_model()
        if musicgen_model is None:  # Loading failed; the second value holds the error message
            return musicgen_processor
try:
musicgen_model.to("cuda") # Move the model to GPU
        inputs = musicgen_processor(text=[prompt], padding=True, return_tensors="pt").to("cuda")  # Keep inputs on the same device as the model
        outputs = musicgen_model.generate(**inputs, max_new_tokens=int(audio_length))  # Slider values may arrive as floats
musicgen_model.to("cpu") # Return the model to CPU
sr = musicgen_model.config.audio_encoder.sampling_rate
audio_data = outputs[0, 0].cpu().numpy()
        peak = np.max(np.abs(audio_data)) or 1.0  # Guard against division by zero on silent output
        normalized_audio = (audio_data / peak * 32767).astype("int16")
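        # delete=False keeps the WAV on disk so Gradio's filepath Audio output can read it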
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_wav:
write(temp_wav.name, sr, normalized_audio)
return temp_wav.name
except Exception as e:
return f"Error generating audio: {e}"
# ---------------------------------------------------------------------
# Gradio Interface
# ---------------------------------------------------------------------
def generate_script_interface(user_prompt, llama_model_id):
# Load Llama 3 Pipeline with Zero GPU
pipeline_llama = load_llama_pipeline_zero_gpu(llama_model_id, hf_token)
if isinstance(pipeline_llama, str):
return pipeline_llama
# Generate Script
script = generate_script(user_prompt, pipeline_llama)
return script
def generate_audio_interface(script, audio_length):
# Generate Audio
audio_data = generate_audio(script, audio_length)
return audio_data
# ---------------------------------------------------------------------
# Interface
# ---------------------------------------------------------------------
with gr.Blocks() as demo:
gr.Markdown("# 🎧 AI Radio Imaging with Llama 3 + MusicGen (Zero GPU)")
with gr.Row():
user_prompt = gr.Textbox(label="Enter your promo idea", placeholder="E.g., A 15-second hype jingle for a morning talk show.")
        llama_model_id = gr.Textbox(label="Llama 3 Model ID", value="meta-llama/Meta-Llama-3-8B")  # 8B is the smaller Llama 3 variant and the easiest to fit on Spaces hardware
audio_length = gr.Slider(label="Audio Length (tokens)", minimum=128, maximum=1024, step=64, value=512)
with gr.Row():
generate_script_button = gr.Button("Generate Promo Script")
script_output = gr.Textbox(label="Generated Script", interactive=False)
with gr.Row():
generate_audio_button = gr.Button("Generate Audio")
audio_output = gr.Audio(label="Generated Audio", type="filepath")
generate_script_button.click(
generate_script_interface,
inputs=[user_prompt, llama_model_id],
outputs=script_output
)
generate_audio_button.click(
generate_audio_interface,
inputs=[script_output, audio_length],
outputs=audio_output
)
# ---------------------------------------------------------------------
# Launch App
# ---------------------------------------------------------------------
demo.launch(debug=True)
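# On Spaces the app is served automatically once launch() is called; debug=True surfaces full
# tracebacks in the container logs, which helps when diagnosing model-loading failures.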