import gradio as gr
import os
import torch
import numpy as np
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
pipeline,
AutoProcessor,
MusicgenForConditionalGeneration,
)
from scipy.io.wavfile import write
from TTS.api import TTS
import tempfile
from dotenv import load_dotenv
import spaces
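# Rough dependency list implied by the imports above (an assumption; pin exact
# versions in requirements.txt): gradio, torch, numpy, transformers, scipy,
# coqui TTS, python-dotenv, spaces.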
# Load environment variables
load_dotenv()
hf_token = os.getenv("HF_TOKEN")
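# Note: meta-llama models are gated on the Hub, so HF_TOKEN must belong to an
# account that has been granted access to the requested model.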
# ---------------------------------------------------------------------
# Load Llama 3 Pipeline with Zero GPU (Encapsulated)
# ---------------------------------------------------------------------
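# On ZeroGPU Spaces, a GPU is attached only while a @spaces.GPU-decorated
# function runs, so the tokenizer and model are loaded fresh on every call.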
@spaces.GPU(duration=300)
def generate_script(user_prompt: str, duration: int, model_id: str, token: str):
    try:
        tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            token=token,
            torch_dtype=torch.float16,
            device_map="auto",
            trust_remote_code=True,
        )
        llama_pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer)
        system_prompt = (
            "You are an expert radio imaging producer specializing in sound design and music. "
            f"Generate a concise, creative promo script for a {duration}-second ad, focusing on auditory elements and musical appeal."
        )
        combined_prompt = f"{system_prompt}\nUser concept: {user_prompt}\nRefined script:"
        result = llama_pipeline(combined_prompt, max_new_tokens=200, do_sample=True, temperature=0.9)
        # Keep only the text generated after the "Refined script:" marker.
        return result[0]["generated_text"].split("Refined script:")[-1].strip()
    except Exception as e:
        return f"Error generating script: {e}"

# ---------------------------------------------------------------------
# Load MusicGen Model (Encapsulated)
# ---------------------------------------------------------------------
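# MusicGen emits roughly 50 audio tokens per second, so the default of 512
# tokens corresponds to about 10 seconds of music.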
@spaces.GPU(duration=300)
def generate_audio(prompt: str, audio_length: int):
    try:
        musicgen_model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
        musicgen_processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
        device = "cuda" if torch.cuda.is_available() else "cpu"
        musicgen_model.to(device)
        inputs = musicgen_processor(text=[prompt], padding=True, return_tensors="pt").to(device)
        outputs = musicgen_model.generate(**inputs, max_new_tokens=int(audio_length))
        # Peak-normalize the waveform and convert it to 16-bit PCM for WAV export.
        audio_data = outputs[0, 0].cpu().numpy()
        peak = float(np.max(np.abs(audio_data))) or 1.0
        normalized_audio = (audio_data / peak * 32767).astype("int16")
        output_path = f"{tempfile.gettempdir()}/generated_audio.wav"
        write(output_path, musicgen_model.config.audio_encoder.sampling_rate, normalized_audio)
        return output_path
    except Exception as e:
        return f"Error generating audio: {e}"

# ---------------------------------------------------------------------
# Generate Voice-Over with Coqui XTTS-v2
# ---------------------------------------------------------------------
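# XTTS-v2 clones the speaker's timbre from a short (~6 second) reference clip
# and can synthesize the script in any of the supported languages.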
@spaces.GPU(duration=300)
def generate_voice(script: str, reference_audio: str, language: str):
    try:
        tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2", gpu=torch.cuda.is_available())
        output_path = f"{tempfile.gettempdir()}/voice_over.wav"
        tts.tts_to_file(
            text=script,
            file_path=output_path,
            speaker_wav=reference_audio,
            language=language,
        )
        return output_path
    except Exception as e:
        return f"Error generating voice-over: {e}"

# ---------------------------------------------------------------------
# Interface Functions
# ---------------------------------------------------------------------
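# Thin wrappers used as Gradio callbacks; the script wrapper injects the
# server-side HF token so it never appears as a UI input.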
def interface_generate_script(user_prompt, duration, llama_model_id):
    return generate_script(user_prompt, duration, llama_model_id, hf_token)


def interface_generate_audio(script, audio_length):
    return generate_audio(script, audio_length)


def interface_generate_voice(script, reference_audio, language):
    return generate_voice(script, reference_audio, language)

# ---------------------------------------------------------------------
# Interface
# ---------------------------------------------------------------------
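# The UI walks through three steps (script, background music, voice-over); the
# music and voice buttons both read from the shared script textbox.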
with gr.Blocks() as demo:
    gr.Markdown("""
# 🎧 All-in-One Radio Promo Studio 🚀
### Create professional scripts, soundscapes, and voice-overs in minutes!
🔥 Powered by **Llama 3**, **MusicGen**, and **XTTS-v2**
    """)
    # Script Generation Section
    gr.Markdown("## ✍️ Step 1: Generate Your Promo Script")
    with gr.Row():
        user_prompt = gr.Textbox(
            label="🎤 Enter Promo Idea",
            placeholder="E.g., A 15-second energetic jingle for a morning talk show.",
            lines=2
        )
        duration = gr.Dropdown(
            label="⏳ Duration",
            choices=["15", "30", "60"],
            value="15",
            info="Choose the duration of the promo (in seconds)."
        )
        llama_model_id = gr.Textbox(
            label="🎛️ Llama 3 Model ID",
            value="meta-llama/Meta-Llama-3-8B-Instruct"
        )
    generate_script_button = gr.Button("Generate Script ✨")
    script_output = gr.Textbox(label="🖌️ Generated Promo Script", lines=4, interactive=False)
    # Audio Generation Section
    gr.Markdown("## 🎵 Step 2: Generate Background Music")
    with gr.Row():
        audio_length = gr.Slider(
            label="🎶 Audio Length (tokens)",
            minimum=128,
            maximum=1024,
            step=64,
            value=512,
            info="MusicGen produces roughly 50 tokens per second, so 512 tokens is about 10 seconds of audio."
        )
    generate_audio_button = gr.Button("Generate Audio 🎶")
    audio_output = gr.Audio(label="🎵 Generated Audio", type="filepath")
    # Voice-Over Section
    gr.Markdown("## 🎙️ Step 3: Generate Voice-Over")
    with gr.Row():
        reference_audio = gr.Audio(
            label="🎤 Upload Reference Voice (6 seconds)",
            type="filepath"
        )
        language = gr.Dropdown(
            label="🌍 Language",
            choices=["en", "es", "fr", "de", "it"],
            value="en"
        )
    generate_voice_button = gr.Button("Generate Voice-Over 🎤")
    voice_output = gr.Audio(label="🔊 Generated Voice-Over", type="filepath")
    # Footer
    gr.Markdown("""
<br><hr>
<p style="text-align: center; font-size: 0.9em;">
Created with ❀️ by <a href="https://bilsimaging.com" target="_blank">bilsimaging.com</a>
</p>
    """)
    # Button Actions
    generate_script_button.click(
        fn=interface_generate_script,
        inputs=[user_prompt, duration, llama_model_id],
        outputs=script_output
    )
    generate_audio_button.click(
        fn=interface_generate_audio,
        inputs=[script_output, audio_length],
        outputs=audio_output
    )
    generate_voice_button.click(
        fn=interface_generate_voice,
        inputs=[script_output, reference_audio, language],
        outputs=voice_output
    )
# ---------------------------------------------------------------------
# Launch App
# ---------------------------------------------------------------------
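# debug=True surfaces full error tracebacks in the Space logs, which helps while iterating.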
demo.launch(debug=True)