AIPromoStudio / app.py
Bils's picture
Update app.py
5607a62 verified
raw
history blame
11.2 kB
import gradio as gr
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
pipeline,
AutoProcessor,
MusicgenForConditionalGeneration,
)
from scipy.io.wavfile import write
from pydub import AudioSegment
from dotenv import load_dotenv
import tempfile
import spaces
from TTS.api import TTS
import psutil
import GPUtil
# -------------------------------
# Configuration
# -------------------------------
# Populate the process environment through python-dotenv before reading it.
load_dotenv()
# Hugging Face access token: taken from HF_TOKEN, falling back to
# HF_TOKEN_SECRET only when HF_TOKEN is not set at all (an empty HF_TOKEN
# value is used as-is). May be None when neither variable exists.
HF_TOKEN = os.getenv("HF_TOKEN", os.getenv("HF_TOKEN_SECRET"))
# Registry of model identifiers used by the app.
#   "llama_models"   display name -> Hugging Face repo id of a text model
#   "tts_models"     display name -> TTS model name
#   "musicgen_model" single Hugging Face repo id for MusicGen
# The UI dropdowns below are populated from the *values* of the two mappings.
MODEL_CONFIG = {
    "llama_models": {
        "Meta-Llama-3-8B": "meta-llama/Meta-Llama-3-8B-Instruct",
        "Mistral-7B": "mistralai/Mistral-7B-Instruct-v0.2",
    },
    "tts_models": {
        "Standard English": "tts_models/en/ljspeech/tacotron2-DDC",
        "High Quality": "tts_models/en/ljspeech/vits"
    },
    "musicgen_model": "facebook/musicgen-medium"
}
# -------------------------------
# Model Manager with Cache
# -------------------------------
class ModelManager:
    """Lazily load and cache the text, music and speech models used by the app.

    Every ``get_*`` method loads its model on first use and returns the cached
    object on later calls, so repeated UI actions do not reload weights.
    """

    def __init__(self):
        self.llama_pipelines = {}  # model id -> text-generation pipeline
        self.musicgen_model = None
        self.tts_models = {}  # TTS model name -> loaded TTS instance
        self.processor = None  # Add processor cache

    def get_llama_pipeline(self, model_id, token):
        """Return the cached text-generation pipeline for ``model_id``.

        Args:
            model_id: Hugging Face repo id of the causal language model.
            token: Access token forwarded to ``from_pretrained``.

        Returns:
            A ``transformers`` text-generation pipeline.
        """
        if model_id not in self.llama_pipelines:
            tokenizer = AutoTokenizer.from_pretrained(
                model_id,
                token=token,
                legacy=False
            )
            model = AutoModelForCausalLM.from_pretrained(
                model_id,
                token=token,
                torch_dtype=torch.float16,
                device_map="auto",
                low_cpu_mem_usage=True
            )
            self.llama_pipelines[model_id] = pipeline(
                "text-generation",
                model=model,
                tokenizer=tokenizer,
                device_map="auto"
            )
        return self.llama_pipelines[model_id]

    def get_musicgen_model(self):
        """Return ``(model, processor)`` for the configured MusicGen checkpoint.

        The model is moved to CUDA when available, otherwise it stays on CPU.
        """
        if self.musicgen_model is None or self.processor is None:
            checkpoint = MODEL_CONFIG["musicgen_model"]
            # Load both parts into locals first so a failure while loading the
            # processor cannot leave a half-populated cache behind.
            model = MusicgenForConditionalGeneration.from_pretrained(checkpoint)
            processor = AutoProcessor.from_pretrained(checkpoint)
            model.to("cuda" if torch.cuda.is_available() else "cpu")
            self.musicgen_model = model
            self.processor = processor
        return self.musicgen_model, self.processor

    def get_tts_model(self, model_name):
        """Return the cached TTS instance for ``model_name``, loading it on first use.

        Args:
            model_name: TTS model name, e.g. one of the values in
                ``MODEL_CONFIG["tts_models"]``.

        Returns:
            A loaded ``TTS`` object placed on CUDA when available, else CPU.
        """
        if model_name not in self.tts_models:
            device = "cuda" if torch.cuda.is_available() else "cpu"
            self.tts_models[model_name] = TTS(model_name=model_name).to(device)
        return self.tts_models[model_name]
# Module-level manager instance shared by the generation functions below, so
# models loaded by one call are reused by later calls.
model_manager = ModelManager()
# -------------------------------
# Core Functions with Enhanced Error Handling
# -------------------------------
@spaces.GPU
def generate_script(user_prompt, model_id, duration, progress=gr.Progress()):
    """Generate a radio promo script, sound design notes and music brief.

    Args:
        user_prompt: Free-text concept typed by the user.
        model_id: Hugging Face repo id of the text model to use.
        duration: Target promo length in seconds (interpolated into the prompt).
        progress: Gradio progress tracker.

    Returns:
        A list of three strings: voice script, sound design, music description.
        On failure every element is the same ``"Error: ..."`` message, so the
        three output textboxes always receive a value.
    """
    # Reject an empty concept before paying for a model load.
    if not user_prompt or not user_prompt.strip():
        return ["Error: No concept provided"] * 3
    try:
        progress(0.1, "Initializing script generation...")
        text_pipeline = model_manager.get_llama_pipeline(model_id, HF_TOKEN)
        system_prompt = (
            f"Generate a {duration}-second radio promo with:\n"
            "1. Voice Script: [Clear narration, 25-35 words]\n"
            "2. Sound Design: [3-5 specific sound effects]\n"
            "3. Music: [Genre, tempo, mood]\n"
            "Format strictly as:\n"
            "Voice Script: [content]\n"
            "Sound Design: [effects]\n"
            "Music: [description]"
        )
        progress(0.3, "Generating content...")
        response = text_pipeline(
            f"{system_prompt}\nConcept: {user_prompt}",
            max_new_tokens=300,
            temperature=0.7,
            do_sample=True,
            top_p=0.95,
            # Without this the pipeline returns prompt + completion, and the
            # prompt's own "Voice Script: [content]" template lines would be
            # parsed as if the model had written them.
            return_full_text=False,
        )
        progress(0.8, "Parsing results...")
        return parse_generated_content(response[0]["generated_text"])
    except Exception as e:
        return [f"Error: {str(e)}"] * 3
def parse_generated_content(text):
    """Split generated text into its three labelled sections.

    A line starting with ``"Voice Script:"``, ``"Sound Design:"`` or
    ``"Music:"`` opens that section; following non-empty lines are appended to
    it until another header appears. Lines before the first header are dropped.

    Args:
        text: Raw model output. ``None`` is treated as empty text.

    Returns:
        ``[voice_script, sound_design, music]`` as stripped strings; a section
        that never appears is returned as ``""``.
    """
    headers = ("Voice Script", "Sound Design", "Music")
    collected = {header: [] for header in headers}
    active = None
    for raw_line in (text or "").split('\n'):
        line = raw_line.strip()
        for header in headers:
            prefix = header + ":"
            if line.startswith(prefix):
                active = header
                # Drop only the leading header; a later occurrence of the same
                # label inside the sentence is real content and must be kept.
                line = line[len(prefix):].strip()
                break
        if active and line:
            collected[active].append(line)
    return ["\n".join(collected[header]) for header in headers]
@spaces.GPU
def generate_voice(script, tts_model, speed=1.0, progress=gr.Progress()):
    """Synthesize ``script`` to a WAV file with the selected TTS model.

    Args:
        script: Narration text to speak.
        tts_model: TTS model name passed to ``model_manager.get_tts_model``.
        speed: Speaking-rate value forwarded to ``tts_to_file``.
        progress: Gradio progress tracker.

    Returns:
        ``(wav_path, None)`` on success, or ``(None, message)`` when the script
        is empty or synthesis fails.
    """
    # Validate up front so a missing script (None or blank) yields the clear
    # message instead of an AttributeError wrapped as "Voice Error".
    if not script or not script.strip():
        return None, "No script provided"
    try:
        progress(0.2, "Initializing TTS...")
        tts = model_manager.get_tts_model(tts_model)
        # Unique file per call: a fixed name in the shared temp directory would
        # let concurrent requests overwrite each other's audio.
        with tempfile.NamedTemporaryFile(
            prefix="voice_", suffix=".wav", delete=False
        ) as handle:
            output_path = handle.name
        progress(0.5, "Generating audio...")
        tts.tts_to_file(text=script, file_path=output_path, speed=speed)
        return output_path, None
    except Exception as e:
        return None, f"Voice Error: {str(e)}"
@spaces.GPU
def generate_music(prompt, duration_sec=30, progress=gr.Progress()):
    """Generate a music track from a text description with MusicGen.

    Args:
        prompt: Text description of the desired music.
        duration_sec: Requested length in seconds; converted to a token budget
            at 50 tokens per second.
        progress: Gradio progress tracker.

    Returns:
        ``(wav_path, None)`` on success, or ``(None, message)`` when the prompt
        is empty or generation fails.
    """
    if not prompt or not prompt.strip():
        return None, "No music description provided"
    try:
        progress(0.1, "Initializing MusicGen...")
        # The manager returns the model together with its cached processor;
        # unpack both instead of reloading the processor on every call.
        model, processor = model_manager.get_musicgen_model()
        progress(0.4, "Processing input...")
        inputs = processor(text=[prompt], padding=True, return_tensors="pt").to(model.device)
        progress(0.6, "Generating music...")
        audio_values = model.generate(**inputs, max_new_tokens=int(duration_sec * 50))
        # Take the sample rate from the loaded checkpoint rather than assuming it.
        sampling_rate = model.config.audio_encoder.sampling_rate
        # Unique file per call so concurrent requests do not overwrite each other.
        with tempfile.NamedTemporaryFile(
            prefix="music_", suffix=".wav", delete=False
        ) as handle:
            output_path = handle.name
        write(output_path, sampling_rate, audio_values[0, 0].cpu().numpy())
        return output_path, None
    except Exception as e:
        return None, f"Music Error: {str(e)}"
def blend_audio(voice_path, music_path, ducking=True, progress=gr.Progress()):
    """Overlay the voice track on the music bed and export the mix as WAV.

    The music is looped if it is shorter than the voice, then trimmed to the
    voice length, so the mix is exactly as long as the narration.

    Args:
        voice_path: Path to the voice WAV file.
        music_path: Path to the music WAV file.
        ducking: When true, lower the music by 10 dB before mixing.
        progress: Gradio progress tracker.

    Returns:
        ``(wav_path, None)`` on success, or ``(None, message)`` on failure.
    """
    if not voice_path or not music_path:
        return None, "Mixing Error: both a voice track and a music track are required"
    try:
        progress(0.2, "Loading audio files...")
        voice = AudioSegment.from_wav(voice_path)
        music = AudioSegment.from_wav(music_path)
        if len(music) == 0:
            # Looping an empty track would divide by zero below.
            return None, "Mixing Error: music track is empty"
        progress(0.4, "Aligning durations...")
        if len(music) < len(voice):
            music = music * (len(voice) // len(music) + 1)
        music = music[:len(voice)]
        progress(0.6, "Mixing audio...")
        if ducking:
            music = music - 10  # 10dB ducking
        mixed = music.overlay(voice)
        # Unique file per call so concurrent requests do not overwrite each other.
        with tempfile.NamedTemporaryFile(
            prefix="final_mix_", suffix=".wav", delete=False
        ) as handle:
            output_path = handle.name
        mixed.export(output_path, format="wav")
        return output_path, None
    except Exception as e:
        return None, f"Mixing Error: {str(e)}"
# -------------------------------
# UI Components
# -------------------------------
def create_audio_visualization(audio_path):
    """Render the waveform of an audio file to a PNG and return its path.

    Args:
        audio_path: Path to an audio file readable by pydub. Falsy values
            (``None``, ``""``) mean "nothing to draw".

    Returns:
        Path of the generated PNG, or ``None`` when ``audio_path`` is falsy.
    """
    if not audio_path:
        return None
    audio = AudioSegment.from_file(audio_path)
    samples = np.array(audio.get_array_of_samples())
    if audio.channels > 1:
        # Multi-channel samples are interleaved; plot the first channel only
        # so the trace is a real waveform rather than alternating channels.
        samples = samples[::audio.channels]
    fig, ax = plt.subplots(figsize=(10, 3))
    try:
        ax.plot(samples)
        ax.axis('off')
        fig.tight_layout()
        # Unique file per call: voice, music and final-mix previews would
        # otherwise all overwrite the same image.
        with tempfile.NamedTemporaryFile(
            prefix="waveform_", suffix=".png", delete=False
        ) as handle:
            image_path = handle.name
        fig.savefig(image_path, bbox_inches='tight', pad_inches=0)
    finally:
        # Always release the figure, even if plotting or saving failed.
        plt.close(fig)
    return image_path
def system_monitor():
    """Report current CPU, RAM and first-GPU utilisation as display strings.

    Returns:
        A dict with keys ``"CPU"``, ``"RAM"`` and ``"GPU"``. CPU and RAM are
        percentages as reported by psutil; GPU is the load of the first GPU
        listed by GPUtil with one decimal, or ``"N/A"`` when none is listed.
    """
    detected_gpus = GPUtil.getGPUs()
    usage = {}
    usage["CPU"] = f"{psutil.cpu_percent()}%"
    usage["RAM"] = f"{psutil.virtual_memory().percent}%"
    if detected_gpus:
        first_gpu_load = detected_gpus[0].load * 100
        usage["GPU"] = f"{first_gpu_load:.1f}%"
    else:
        usage["GPU"] = "N/A"
    return usage
# -------------------------------
# Gradio Interface
# -------------------------------
# Soft theme with explicit dark-mode text and background colours.
theme = gr.themes.Soft(
    primary_hue="blue",
    secondary_hue="teal",
).set(
    body_text_color_dark='#FFFFFF',
    background_fill_primary_dark='#1F1F1F'
)


def _audio_with_waveform(result):
    """Convert a handler's ``(audio_path, error)`` into ``(audio_path, waveform_png)``.

    The audio handlers report failures as a message in the second slot, but the
    second UI output is an image. Raise the message as a Gradio error instead
    and render the waveform for successful results.
    """
    audio_path, error = result
    if error:
        raise gr.Error(error)
    return audio_path, create_audio_visualization(audio_path)


def _voice_with_waveform(script, tts_model, speed, progress=gr.Progress()):
    """Click handler: synthesize the voiceover and draw its waveform."""
    return _audio_with_waveform(generate_voice(script, tts_model, speed, progress))


def _music_with_waveform(prompt, progress=gr.Progress()):
    """Click handler: generate the music track and draw its waveform."""
    return _audio_with_waveform(generate_music(prompt, progress=progress))


def _mix_with_waveform(voice_path, music_path, progress=gr.Progress()):
    """Click handler: blend voice and music and draw the mix waveform."""
    return _audio_with_waveform(blend_audio(voice_path, music_path, progress=progress))


with gr.Blocks(theme=theme, title="AI Radio Studio Pro") as demo:
    gr.Markdown("# 🎙️ AI Radio Studio Pro")
    with gr.Row():
        with gr.Column(scale=3):
            concept_input = gr.Textbox(
                label="Concept Description",
                placeholder="Describe your radio segment...",
                lines=3
            )
            with gr.Accordion("Advanced Settings", open=False):
                model_selector = gr.Dropdown(
                    list(MODEL_CONFIG["llama_models"].values()),
                    label="AI Model",
                    value=next(iter(MODEL_CONFIG["llama_models"].values()))
                )
                duration_selector = gr.Slider(15, 120, 30, step=15, label="Duration (seconds)")
            generate_btn = gr.Button("Generate Script", variant="primary")
        with gr.Column(scale=2):
            script_output = gr.Textbox(label="Voice Script", interactive=True)
            sound_output = gr.Textbox(label="Sound Design", interactive=True)
            music_output = gr.Textbox(label="Music Style", interactive=True)
    with gr.Tabs():
        with gr.Tab("🎤 Voice Production"):
            with gr.Row():
                tts_selector = gr.Dropdown(
                    list(MODEL_CONFIG["tts_models"].values()),
                    label="Voice Model",
                    value=next(iter(MODEL_CONFIG["tts_models"].values()))
                )
                speed_selector = gr.Slider(0.5, 2.0, 1.0, step=0.1, label="Speaking Rate")
            voice_btn = gr.Button("Generate Voiceover", variant="primary")
            with gr.Row():
                # type="filepath": this component is later an *input* to the
                # mixer, which opens files by path rather than numpy arrays.
                voice_audio = gr.Audio(label="Voice Preview", interactive=False, type="filepath")
                voice_viz = gr.Image(label="Waveform", interactive=False)
        with gr.Tab("🎵 Music Production"):
            music_btn = gr.Button("Generate Music Track", variant="primary")
            with gr.Row():
                music_audio = gr.Audio(label="Music Preview", interactive=False, type="filepath")
                music_viz = gr.Image(label="Waveform", interactive=False)
        with gr.Tab("🔉 Final Mix"):
            mix_btn = gr.Button("Create Final Mix", variant="primary")
            with gr.Row():
                final_mix_audio = gr.Audio(label="Final Mix", interactive=False)
                final_mix_viz = gr.Image(label="Waveform", interactive=False)
            with gr.Row():
                download_btn = gr.Button("Download Mix")
                play_btn = gr.Button("▶️ Play in Browser")
    with gr.Accordion("📊 System Monitor", open=False):
        monitor = gr.JSON(label="Resource Usage", value=lambda: system_monitor(), every=5)
    gr.Markdown(
        "\n"
        '<div style="text-align: center; padding: 20px; border-top: 1px solid #444;">\n'
        '<p>Created with ❤️ by <a href="https://bilsimaging.com">Bils Imaging</a></p>\n'
        '<img src="https://api.visitorbadge.io/api/visitors?path=https://huggingface.co/spaces/Bils/radiogold&countColor=%23263759">\n'
        "</div>\n"
    )
    # Event Handling
    generate_btn.click(
        generate_script,
        [concept_input, model_selector, duration_selector],
        [script_output, sound_output, music_output]
    )
    # The audio buttons go through the *_with_waveform wrappers: `preprocess`
    # on an event listener is a boolean flag, not a callback, so passing the
    # visualization function there never produced a waveform.
    voice_btn.click(
        _voice_with_waveform,
        [script_output, tts_selector, speed_selector],
        [voice_audio, voice_viz]
    )
    music_btn.click(
        _music_with_waveform,
        [music_output],
        [music_audio, music_viz]
    )
    mix_btn.click(
        _mix_with_waveform,
        [voice_audio, music_audio],
        [final_mix_audio, final_mix_viz]
    )
# Start the Gradio server only when this file is executed as a script.
if __name__ == "__main__":
    # Listen on all network interfaces, port 7860.
    demo.launch(server_name="0.0.0.0", server_port=7860)