# AIPromoStudio / app.py
import gradio as gr
import os
import torch
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
pipeline,
AutoProcessor,
MusicgenForConditionalGeneration,
)
from scipy.io.wavfile import write
from pydub import AudioSegment
from dotenv import load_dotenv
import tempfile
import spaces
from TTS.api import TTS
# -------------------------------
# Configuration
# -------------------------------
load_dotenv()
HF_TOKEN = os.getenv("HF_TOKEN")
MODEL_CONFIG = {
"llama_models": {
"Meta-Llama-3-8B": "meta-llama/Meta-Llama-3-8B-Instruct",
"Mistral-7B": "mistralai/Mistral-7B-Instruct-v0.2",
},
"tts_models": {
"Standard English": "tts_models/en/ljspeech/tacotron2-DDC",
"High Quality": "tts_models/en/ljspeech/vits",
}
}
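# Example of the lookup the UI dropdowns rely on (display name -> HF model id):
#   MODEL_CONFIG["llama_models"]["Mistral-7B"]  ->  "mistralai/Mistral-7B-Instruct-v0.2"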
# -------------------------------
# Model Manager
# -------------------------------
class ModelManager:
def __init__(self):
self.llama_pipelines = {}
self.musicgen_models = {}
self.tts_models = {}
def get_llama_pipeline(self, model_id, token):
if model_id not in self.llama_pipelines:
            # `use_auth_token` is deprecated in recent transformers releases;
            # pass the token via the `token` argument instead.
            tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
            model = AutoModelForCausalLM.from_pretrained(
                model_id,
                token=token,
                torch_dtype=torch.float16,
                device_map="auto"
            )
self.llama_pipelines[model_id] = pipeline(
"text-generation",
model=model,
tokenizer=tokenizer
)
return self.llama_pipelines[model_id]
def get_musicgen_model(self, model_key="facebook/musicgen-large"):
if model_key not in self.musicgen_models:
model = MusicgenForConditionalGeneration.from_pretrained(model_key)
processor = AutoProcessor.from_pretrained(model_key)
self.musicgen_models[model_key] = (model, processor)
return self.musicgen_models[model_key]
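    # Note: facebook/musicgen-large is a ~3.3B-parameter checkpoint; swapping in
    # "facebook/musicgen-small" here would trade quality for speed and memory.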
def get_tts_model(self, model_name):
if model_name not in self.tts_models:
self.tts_models[model_name] = TTS(model_name)
return self.tts_models[model_name]
model_manager = ModelManager()
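# Each get_* method loads a model on first request and caches it, so repeated
# clicks in the UI reuse already-loaded weights. A hypothetical warm-up call
# (not part of the app flow) would look like:
#   model_manager.get_tts_model("tts_models/en/ljspeech/vits")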
# -------------------------------
# Core Functions
# -------------------------------
@spaces.GPU
def generate_script(user_prompt, model_id, duration):
try:
text_pipeline = model_manager.get_llama_pipeline(model_id, HF_TOKEN)
system_prompt = f"""Create a {duration}-second audio promo with these elements:
1. Voice Script: [Clear narration]
2. Sound Design: [3-5 effects]
3. Music: [Genre/tempo]
Concept: {user_prompt}"""
        result = text_pipeline(
            system_prompt,
            max_new_tokens=300,
            temperature=0.7,
            do_sample=True,
            # Return only the completion; otherwise the echoed prompt's own
            # "Voice Script:" / "Sound Design:" headings confuse the parser.
            return_full_text=False
        )
generated_text = result[0]["generated_text"]
return parse_generated_content(generated_text)
except Exception as e:
return f"Error: {str(e)}", "", ""
def parse_generated_content(text):
sections = {
"Voice Script": "",
"Sound Design": "",
"Music": ""
}
current_section = None
for line in text.split('\n'):
line = line.strip()
if "Voice Script:" in line:
current_section = "Voice Script"
line = line.replace("Voice Script:", "").strip()
elif "Sound Design:" in line:
current_section = "Sound Design"
line = line.replace("Sound Design:", "").strip()
elif "Music:" in line:
current_section = "Music"
line = line.replace("Music:", "").strip()
if current_section and line:
sections[current_section] += line + "\n"
return sections["Voice Script"].strip(), sections["Sound Design"].strip(), sections["Music"].strip()
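# A minimal sketch of the parser's contract, assuming a hypothetical
# well-formed completion (real model output varies):
#   parse_generated_content("Voice Script: Hello\nSound Design: Rain\nMusic: Jazz")
#   -> ("Hello", "Rain", "Jazz")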
@spaces.GPU
def generate_voice(script, tts_model):
try:
if not script.strip():
return "Error: No script provided"
tts = model_manager.get_tts_model(tts_model)
output_path = os.path.join(tempfile.gettempdir(), "voice.wav")
tts.tts_to_file(text=script, file_path=output_path)
return output_path
except Exception as e:
return f"Error: {str(e)}"
@spaces.GPU
def generate_music(prompt, duration_sec=30):
    try:
        model, processor = model_manager.get_musicgen_model()
        # Move the cached model and inputs onto the GPU when one is available.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model.to(device)
        inputs = processor(text=[prompt], padding=True, return_tensors="pt").to(device)
        audio_values = model.generate(**inputs, max_new_tokens=int(duration_sec * 50))
        output_path = os.path.join(tempfile.gettempdir(), "music.wav")
        # Write at the model's own sampling rate (32 kHz for MusicGen);
        # hard-coding 44.1 kHz would shift the pitch and shorten the track.
        sampling_rate = model.config.audio_encoder.sampling_rate
        write(output_path, sampling_rate, audio_values[0, 0].cpu().numpy())
        return output_path
    except Exception as e:
        return f"Error: {str(e)}"
def blend_audio(voice_path, music_path, ducking=True, duck_level=10):
try:
voice = AudioSegment.from_wav(voice_path)
music = AudioSegment.from_wav(music_path)
# Align durations
if len(music) < len(voice):
music = music * (len(voice) // len(music) + 1)
music = music[:len(voice)]
        # Simple ducking: attenuate the whole music bed by duck_level dB
if ducking:
music = music - duck_level
mixed = music.overlay(voice)
output_path = os.path.join(tempfile.gettempdir(), "final_mix.wav")
mixed.export(output_path, format="wav")
return output_path
except Exception as e:
return f"Error: {str(e)}"
# -------------------------------
# Gradio Interface
# -------------------------------
with gr.Blocks(title="AI Radio Studio", css="""
.gradio-container {max-width: 800px; margin: auto;}
.tab-item {padding: 20px; border-radius: 10px;}
""") as demo:
gr.Markdown("""
# πŸŽ™οΈ AI Radio Studio
*Professional Audio Production Made Simple*
""")
with gr.Tabs():
# Concept Tab
with gr.Tab("🎯 Concept"):
with gr.Row():
with gr.Column():
concept_input = gr.Textbox(
label="Your Idea",
placeholder="Describe your audio project...",
lines=3
)
model_select = gr.Dropdown(
choices=list(MODEL_CONFIG["llama_models"].values()),
label="AI Model",
value="meta-llama/Meta-Llama-3-8B-Instruct"
)
duration_select = gr.Slider(15, 60, 30, step=15, label="Duration (seconds)")
generate_btn = gr.Button("Generate Script", variant="primary")
with gr.Column():
script_output = gr.Textbox(label="Voice Script", interactive=True)
sound_output = gr.Textbox(label="Sound Design", interactive=True)
music_output = gr.Textbox(label="Music Suggestions", interactive=True)
# Voice Tab
with gr.Tab("πŸ—£οΈ Voice"):
with gr.Row():
with gr.Column():
tts_select = gr.Dropdown(
choices=list(MODEL_CONFIG["tts_models"].values()),
label="Voice Model",
value="tts_models/en/ljspeech/tacotron2-DDC"
)
voice_btn = gr.Button("Generate Voiceover", variant="primary")
with gr.Column():
voice_preview = gr.Audio(label="Preview", type="filepath")
# Music Tab
with gr.Tab("🎡 Music"):
music_btn = gr.Button("Generate Music Track", variant="primary")
music_preview = gr.Audio(label="Preview", type="filepath")
# Mix Tab
with gr.Tab("πŸ”Š Mix"):
with gr.Row():
with gr.Column():
ducking_toggle = gr.Checkbox(True, label="Enable Voice Ducking")
duck_level = gr.Slider(0, 20, 10, label="Ducking Level (dB)")
mix_btn = gr.Button("Create Final Mix", variant="primary")
with gr.Column():
final_mix = gr.Audio(label="Final Output", type="filepath")
# Footer Section
gr.Markdown("""
<div style="text-align: center; margin-top: 30px; padding: 15px; border-top: 1px solid #e0e0e0;">
<p style="font-size: 0.9em; color: #666;">
Created with ❀️ by <a href="https://bilsimaging.com" target="_blank">bilsimaging.com</a>
</p>
<a href="https://visitorbadge.io/status?path=https://huggingface.co/spaces/Bils/radiogold">
<img src="https://api.visitorbadge.io/api/visitors?path=https%3A%2F%2Fhuggingface.co%2Fspaces%2FBils%2Fradiogold&countColor=%23263759"/>
</a>
</div>
""")
# Event Handlers
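    # Because the Audio components use type="filepath", voice_preview and
    # music_preview hand temp-file paths straight to blend_audio.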
generate_btn.click(
generate_script,
inputs=[concept_input, model_select, duration_select],
outputs=[script_output, sound_output, music_output]
)
voice_btn.click(
generate_voice,
inputs=[script_output, tts_select],
outputs=voice_preview
)
music_btn.click(
generate_music,
inputs=[music_output],
outputs=music_preview
)
mix_btn.click(
blend_audio,
inputs=[voice_preview, music_preview, ducking_toggle, duck_level],
outputs=final_mix
)
if __name__ == "__main__":
demo.launch(server_name="0.0.0.0", server_port=7860)