AIPromoStudio / app.py
Bils's picture
Update app.py
ff307f8 verified
raw
history blame
7.87 kB
import os
import requests
import torch
import scipy.io.wavfile as wav
import streamlit as st
from io import BytesIO
from transformers import (
AutoTokenizer,
AutoModelForCausalLM,
pipeline,
AutoProcessor,
MusicgenForConditionalGeneration
)
from streamlit_lottie import st_lottie
# ---------------------------------------------------------------------
# 1) PAGE CONFIGURATION
# ---------------------------------------------------------------------
# Page-level settings: browser-tab title, emoji icon and wide layout.
_PAGE_SETTINGS = {
    "page_title": "AI Radio Imaging with Llama 3",
    "page_icon": "🎧",  # literal emoji rather than a Unicode escape
    "layout": "wide",
}
st.set_page_config(**_PAGE_SETTINGS)
# ---------------------------------------------------------------------
# 2) CUSTOM CSS / UI DESIGN
# ---------------------------------------------------------------------
# Stylesheet for the app: dark background with white text, green (#1DB954)
# headings and buttons, rounded dark form inputs, full-width audio player,
# a centred ".footer-note" class used by the page footer, and rules that
# hide the "#MainMenu" and "footer" elements.
CUSTOM_CSS = """
<style>
body {
background-color: #121212;
color: #FFFFFF;
font-family: "Helvetica Neue", sans-serif;
}
.block-container {
max-width: 1100px;
padding: 1rem 1.5rem;
}
h1, h2, h3 {
color: #1DB954;
}
.stButton>button {
background-color: #1DB954 !important;
color: #FFFFFF !important;
border-radius: 24px;
padding: 0.6rem 1.2rem;
}
.stButton>button:hover {
background-color: #1ed760 !important;
}
textarea, input, select {
border-radius: 8px !important;
background-color: #282828 !important;
color: #FFFFFF !important;
}
audio {
width: 100%;
margin-top: 1rem;
}
.footer-note {
text-align: center;
font-size: 14px;
opacity: 0.7;
margin-top: 2rem;
}
#MainMenu, footer {visibility: hidden;}
</style>
"""
# The stylesheet is raw HTML, so it has to be passed with
# unsafe_allow_html=True to be emitted as markup instead of shown as text.
st.markdown(CUSTOM_CSS, unsafe_allow_html=True)
# ---------------------------------------------------------------------
# 3) LOAD LOTTIE ANIMATION
# ---------------------------------------------------------------------
@st.cache_data
def load_lottie_url(url: str):
    """Download a Lottie animation and return its parsed JSON.

    The animation is purely decorative, so every failure mode (network
    error, timeout, non-200 status, body that is not valid JSON) yields
    ``None`` instead of raising and taking the whole page down.

    Args:
        url: Address of the Lottie JSON document.

    Returns:
        The decoded JSON payload, or ``None`` when it could not be fetched
        or decoded.
    """
    try:
        # Bound the wait so an unresponsive host cannot hang app start-up.
        response = requests.get(url, timeout=10)
    except requests.RequestException:
        return None
    if response.status_code != 200:
        return None
    try:
        return response.json()
    except ValueError:
        # requests raises a ValueError subclass when the body is not JSON.
        return None
# Animation displayed in the page header. `lottie_animation` is None when
# the download fails; the header only renders it when it is truthy.
LOTTIE_URL = "https://assets3.lottiefiles.com/temp/lf20_Q6h5zV.json"
lottie_animation = load_lottie_url(LOTTIE_URL)
# ---------------------------------------------------------------------
# 4) LOAD LLAMA 3 (GATED MODEL)
# ---------------------------------------------------------------------
@st.cache_resource
def load_llama_pipeline(model_id: str, device: str, token: str):
    """Load a (gated) causal language model and wrap it in a text-generation pipeline.

    The result is cached by Streamlit, so the weights are loaded once per
    distinct ``(model_id, device, token)`` combination.

    Args:
        model_id: Hugging Face Hub model identifier.
        device: Value forwarded as ``device_map``; the UI offers ``"auto"``
            and ``"cpu"``.
        token: Hugging Face access token used to download the gated model.

    Returns:
        A ``transformers`` text-generation pipeline built from the loaded
        model and tokenizer.
    """
    # `token` is the supported keyword; `use_auth_token` is deprecated.
    tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
    # Half precision for the "auto" placement, full precision otherwise.
    dtype = torch.float16 if device == "auto" else torch.float32
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        token=token,
        torch_dtype=dtype,
        device_map=device,
    )
    return pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        device_map=device,
    )
# ---------------------------------------------------------------------
# 5) GENERATE RADIO SCRIPT
# ---------------------------------------------------------------------
def generate_radio_script(user_input: str, pipeline_llama) -> str:
    """Turn a user's promo concept into a short radio script.

    Args:
        user_input: Free-text description of the promo the user wants.
        pipeline_llama: Callable invoked as
            ``pipeline_llama(prompt, max_new_tokens=..., do_sample=...,
            temperature=...)`` that returns a list whose first item is a
            dict with a ``"generated_text"`` entry.

    Returns:
        The generated script with the prompt removed, followed by an
        attribution line.
    """
    marker = "Refined script:"
    system_prompt = (
        "You are a top-tier radio imaging producer using Llama 3. "
        "Take the user's concept and craft a short, creative promo script."
    )
    combined_prompt = f"{system_prompt}\nUser concept: {user_input}\n{marker}"
    result = pipeline_llama(
        combined_prompt,
        max_new_tokens=200,
        do_sample=True,
        temperature=0.9,
    )
    output_text = result[0]["generated_text"]
    if output_text.startswith(combined_prompt):
        # Drop the echoed prompt exactly. Splitting on the marker alone would
        # cut at the wrong place when the user's concept contains the marker.
        output_text = output_text[len(combined_prompt):].strip()
    elif marker in output_text:
        # The prompt was not echoed verbatim; fall back to the marker.
        output_text = output_text.split(marker, 1)[-1].strip()
    output_text += "\n\n(Generated by Llama 3 - Radio Imaging)"
    return output_text
# ---------------------------------------------------------------------
# 6) LOAD MUSICGEN
# ---------------------------------------------------------------------
@st.cache_resource
def load_musicgen_model():
    """Load the MusicGen model and its processor (cached by Streamlit).

    Returns:
        A ``(model, processor)`` tuple, both loaded from the
        ``facebook/musicgen-small`` checkpoint.
    """
    checkpoint = "facebook/musicgen-small"
    model = MusicgenForConditionalGeneration.from_pretrained(checkpoint)
    processor = AutoProcessor.from_pretrained(checkpoint)
    return model, processor
# ---------------------------------------------------------------------
# 7) HEADER
# ---------------------------------------------------------------------
# Page header: title, tagline, short introduction and optional animation.
# The title uses the literal emoji: in a Python str, "\ud83c\udfa7" is a pair
# of lone surrogates, not U+1F3A7, and cannot be encoded as UTF-8.
st.title("🎧 AI Radio Imaging with Llama 3")
st.subheader("Create engaging radio promos with Llama 3 + MusicGen")
st.markdown("""Create **radio imaging promos** and **jingles** easily. Ensure you have access to
**meta-llama/Meta-Llama-3-70B** on Hugging Face and provide your token below.""")
# The animation is skipped when it could not be downloaded.
if lottie_animation:
    st_lottie(lottie_animation, height=180, loop=True, key="radio_lottie")
st.markdown("---")
# ---------------------------------------------------------------------
# 8) USER INPUT
# ---------------------------------------------------------------------
# Step 1: collect the promo concept and model settings, then generate the
# script with Llama.
# Literal emoji: the "\ud83c\udfa4" escape is a lone surrogate pair in Python.
st.subheader("🎤 Step 1: Describe Your Promo Idea")
prompt = st.text_area(
    "Example: 'A 15-second hype jingle for a morning talk show, fun and energetic.'",
    height=120
)
col_model, col_device = st.columns(2)
with col_model:
    llama_model_id = st.text_input(
        "Llama 3 Model ID",
        value="meta-llama/Meta-Llama-3-70B",
        help="Enter the exact model ID from Hugging Face."
    )
with col_device:
    device_option = st.selectbox(
        "Device",
        ["auto", "cpu"],
        help="Choose GPU (auto) or CPU."
    )
# The gated model cannot be downloaded without a token, so stop rendering
# the rest of the page when none is configured.
hf_token = os.getenv("HF_TOKEN")
if not hf_token:
    st.error("No HF_TOKEN found. Please set it in your environment.")
    st.stop()
if st.button("\u270d Generate Promo Script"):
    if not prompt.strip():
        st.error("Please provide a concept first.")
    else:
        with st.spinner("Generating script..."):
            try:
                llama_pipeline = load_llama_pipeline(llama_model_id, device_option, hf_token)
                final_script = generate_radio_script(prompt, llama_pipeline)
                # Persist the script across reruns: the audio step reads it
                # from session state, and a local variable is lost as soon
                # as the next button click reruns the script.
                st.session_state["final_script"] = final_script
                st.success("Promo script generated!")
                st.text_area("Generated Script", value=final_script, height=200)
            except Exception as e:
                st.error(f"Llama generation error: {e}")
st.markdown("---")
# ---------------------------------------------------------------------
# 9) GENERATE AUDIO WITH MUSICGEN
# ---------------------------------------------------------------------
# Step 2: turn the stored script into audio with MusicGen.
# Literal emoji are used because "\ud83c\udfb5"-style escapes are lone
# surrogate pairs in a Python str and cannot be encoded as UTF-8.
st.subheader("🎵 Step 2: Generate Audio")
audio_length = st.slider("Track Length (tokens)", 128, 1024, 512, 64)
if st.button("🎧 Create Audio"):
    if "final_script" not in st.session_state:
        st.error("Please generate a script first.")
    else:
        with st.spinner("Generating audio..."):
            try:
                mg_model, mg_processor = load_musicgen_model()
                inputs = mg_processor(
                    text=[st.session_state["final_script"]],
                    padding=True,
                    return_tensors="pt"
                )
                audio_values = mg_model.generate(**inputs, max_new_tokens=audio_length)
                sr = mg_model.config.audio_encoder.sampling_rate
                output_file = "radio_jingle.wav"
                # presumably [batch, channel, samples]; takes the first of each
                audio_data = audio_values[0, 0].cpu().numpy()
                # Vectorised peak instead of the Python builtin max(), which
                # iterates sample by sample.
                peak = abs(audio_data).max() if audio_data.size else 0
                if peak > 0:
                    # Scale so the loudest sample hits the int16 limit.
                    normalized_audio = (audio_data / peak * 32767).astype("int16")
                else:
                    # Silent or empty output: avoid dividing by zero.
                    normalized_audio = audio_data.astype("int16")
                wav.write(output_file, rate=sr, data=normalized_audio)
                st.success("Audio generated! Play it below:")
                st.audio(output_file)
            except Exception as e:
                st.error(f"MusicGen error: {e}")
# ---------------------------------------------------------------------
# 10) FOOTER
# ---------------------------------------------------------------------
# Footer markup; the "footer-note" class is styled by the app's stylesheet.
_FOOTER_HTML = """
<div class="footer-note">
© 2025 AI Radio Imaging – Built with Hugging Face & Streamlit
</div>
"""
st.markdown("---")
st.markdown(_FOOTER_HTML, unsafe_allow_html=True)