Spaces:
Running
on
Zero
Running
on
Zero
import os | |
import requests | |
import torch | |
import scipy.io.wavfile as wav | |
import streamlit as st | |
from io import BytesIO | |
from transformers import ( | |
AutoTokenizer, | |
AutoModelForCausalLM, | |
pipeline, | |
AutoProcessor, | |
MusicgenForConditionalGeneration | |
) | |
from streamlit_lottie import st_lottie | |
# --------------------------------------------------------------------- | |
# 1) PAGE CONFIGURATION | |
# --------------------------------------------------------------------- | |
# Browser-tab metadata and overall page layout for the app.
st.set_page_config(
    layout="wide",
    page_icon="🎧",  # literal emoji rather than a Unicode escape sequence
    page_title="AI Radio Imaging with Llama 3",
)
# --------------------------------------------------------------------- | |
# 2) CUSTOM CSS / UI DESIGN | |
# --------------------------------------------------------------------- | |
# Stylesheet injected into the page: dark background, green accent colour on
# headings and buttons, rounded form inputs, full-width audio player, a muted
# centred footer, and hidden default Streamlit menu/footer elements.
CUSTOM_CSS = """
<style>
body {
background-color: #121212;
color: #FFFFFF;
font-family: "Helvetica Neue", sans-serif;
}
.block-container {
max-width: 1100px;
padding: 1rem 1.5rem;
}
h1, h2, h3 {
color: #1DB954;
}
.stButton>button {
background-color: #1DB954 !important;
color: #FFFFFF !important;
border-radius: 24px;
padding: 0.6rem 1.2rem;
}
.stButton>button:hover {
background-color: #1ed760 !important;
}
textarea, input, select {
border-radius: 8px !important;
background-color: #282828 !important;
color: #FFFFFF !important;
}
audio {
width: 100%;
margin-top: 1rem;
}
.footer-note {
text-align: center;
font-size: 14px;
opacity: 0.7;
margin-top: 2rem;
}
#MainMenu, footer {visibility: hidden;}
</style>
"""

# The payload is raw HTML (<style> tag), so HTML rendering must be allowed.
st.markdown(CUSTOM_CSS, unsafe_allow_html=True)
# --------------------------------------------------------------------- | |
# 3) LOAD LOTTIE ANIMATION | |
# --------------------------------------------------------------------- | |
def load_lottie_url(url: str):
    """Fetch a Lottie animation description (JSON) from *url*.

    The animation is purely decorative, so every failure mode degrades to
    ``None`` instead of raising.

    Args:
        url: HTTP(S) address of the Lottie JSON file.

    Returns:
        The decoded JSON payload, or ``None`` when the request fails, the
        server answers with a status other than 200, or the body is not
        valid JSON.
    """
    try:
        # A timeout keeps an unresponsive host from hanging the script.
        response = requests.get(url, timeout=10)
    except requests.RequestException:
        return None

    if response.status_code != 200:
        return None

    try:
        return response.json()
    except ValueError:
        # JSON decoding errors are ValueError subclasses.
        return None
# Remote Lottie asset for the header; `lottie_animation` holds whatever
# load_lottie_url returns for it (None when no animation was obtained).
LOTTIE_URL = "https://assets3.lottiefiles.com/temp/lf20_Q6h5zV.json"
lottie_animation = load_lottie_url(LOTTIE_URL)
# --------------------------------------------------------------------- | |
# 4) LOAD LLAMA 3 (GATED MODEL) | |
# --------------------------------------------------------------------- | |
def load_llama_pipeline(model_id: str, device: str, token: str):
    """Build a text-generation pipeline for a (possibly gated) causal LM.

    Args:
        model_id: Hugging Face Hub identifier of the model to load.
        device: Value forwarded as ``device_map``. ``"auto"`` loads the
            weights in float16; any other value loads them in float32.
        token: Hugging Face access token used to download the tokenizer
            and the model.

    Returns:
        A ``transformers`` "text-generation" pipeline wrapping the loaded
        model and tokenizer.
    """
    # `token` is the current keyword; `use_auth_token` is deprecated.
    tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        token=token,
        torch_dtype=torch.float16 if device == "auto" else torch.float32,
        device_map=device,
    )
    text_gen_pipeline = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        device_map=device,
    )
    return text_gen_pipeline
# --------------------------------------------------------------------- | |
# 5) GENERATE RADIO SCRIPT | |
# --------------------------------------------------------------------- | |
def generate_radio_script(user_input: str, pipeline_llama) -> str:
    """Turn a user's promo concept into a short radio script.

    Args:
        user_input: Free-text description of the promo idea.
        pipeline_llama: Callable invoked as
            ``pipeline_llama(prompt, max_new_tokens=..., do_sample=...,
            temperature=...)`` that returns a sequence whose first item is
            a mapping with a ``"generated_text"`` entry.

    Returns:
        The generated script with the prompt removed, followed by an
        attribution line.
    """
    marker = "Refined script:"
    system_prompt = (
        "You are a top-tier radio imaging producer using Llama 3. "
        "Take the user's concept and craft a short, creative promo script."
    )
    combined_prompt = f"{system_prompt}\nUser concept: {user_input}\n{marker}"

    result = pipeline_llama(
        combined_prompt,
        max_new_tokens=200,
        do_sample=True,
        temperature=0.9,
    )
    output_text = result[0]["generated_text"]

    if output_text.startswith(combined_prompt):
        # Strip the echoed prompt exactly. Splitting on the marker alone
        # would cut inside the prompt whenever the user's concept itself
        # contains "Refined script:".
        output_text = output_text[len(combined_prompt):].strip()
    elif marker in output_text:
        # Fallback when the prompt is not echoed verbatim.
        output_text = output_text.split(marker, 1)[-1].strip()

    output_text += "\n\n(Generated by Llama 3 - Radio Imaging)"
    return output_text
# --------------------------------------------------------------------- | |
# 6) LOAD MUSICGEN | |
# --------------------------------------------------------------------- | |
def load_musicgen_model():
    """Load the ``facebook/musicgen-small`` checkpoint and its processor.

    Returns:
        A ``(model, processor)`` tuple.
    """
    checkpoint = "facebook/musicgen-small"
    model = MusicgenForConditionalGeneration.from_pretrained(checkpoint)
    processor = AutoProcessor.from_pretrained(checkpoint)
    return model, processor
# --------------------------------------------------------------------- | |
# 7) HEADER | |
# --------------------------------------------------------------------- | |
# Page header. The title uses the literal emoji: in Python 3 the escape
# "\ud83c\udfa7" yields two lone surrogates rather than U+1F3A7, and lone
# surrogates cannot be encoded as UTF-8.
st.title("🎧 AI Radio Imaging with Llama 3")
st.subheader("Create engaging radio promos with Llama 3 + MusicGen")
st.markdown("""Create **radio imaging promos** and **jingles** easily. Ensure you have access to
**meta-llama/Meta-Llama-3-70B** on Hugging Face and provide your token below.""")

# The animation is optional; it is skipped when it was not loaded.
if lottie_animation:
    st_lottie(lottie_animation, height=180, loop=True, key="radio_lottie")
st.markdown("---")
# --------------------------------------------------------------------- | |
# 8) USER INPUT | |
# --------------------------------------------------------------------- | |
# Step 1: collect the promo concept and model settings, then generate a script.
# The heading uses the literal emoji: the escape "\ud83c\udfa4" is a pair of
# lone surrogates in Python 3 and cannot be encoded as UTF-8.
st.subheader("🎤 Step 1: Describe Your Promo Idea")
prompt = st.text_area(
    "Example: 'A 15-second hype jingle for a morning talk show, fun and energetic.'",
    height=120,
)

col_model, col_device = st.columns(2)
with col_model:
    llama_model_id = st.text_input(
        "Llama 3 Model ID",
        value="meta-llama/Meta-Llama-3-70B",
        help="Enter the exact model ID from Hugging Face.",
    )
with col_device:
    device_option = st.selectbox(
        "Device",
        ["auto", "cpu"],
        help="Choose GPU (auto) or CPU.",
    )

# The Hugging Face token is read from the environment; without it the
# script stops here and nothing below is rendered.
hf_token = os.getenv("HF_TOKEN")
if not hf_token:
    st.error("No HF_TOKEN found. Please set it in your environment.")
    st.stop()

if st.button("\u270d Generate Promo Script"):
    if not prompt.strip():
        st.error("Please provide a concept first.")
    else:
        with st.spinner("Generating script..."):
            try:
                llama_pipeline = load_llama_pipeline(llama_model_id, device_option, hf_token)
                final_script = generate_radio_script(prompt, llama_pipeline)
                # Persist the script across reruns: the audio step looks it
                # up in st.session_state, so without this assignment that
                # step always reports that no script has been generated.
                st.session_state["final_script"] = final_script
                st.success("Promo script generated!")
                st.text_area("Generated Script", value=final_script, height=200)
            except Exception as e:
                # UI boundary: surface any failure to the user.
                st.error(f"Llama generation error: {e}")
st.markdown("---")
# --------------------------------------------------------------------- | |
# 9) GENERATE AUDIO WITH MUSICGEN | |
# --------------------------------------------------------------------- | |
# Step 2: turn the stored script into audio with MusicGen.
# Labels use literal emoji: the escapes "\ud83c\udfb5" / "\ud83c\udfa7" are
# pairs of lone surrogates in Python 3 and cannot be encoded as UTF-8.
st.subheader("🎵 Step 2: Generate Audio")
audio_length = st.slider("Track Length (tokens)", 128, 1024, 512, 64)

if st.button("🎧 Create Audio"):
    if "final_script" not in st.session_state:
        st.error("Please generate a script first.")
    else:
        with st.spinner("Generating audio..."):
            try:
                mg_model, mg_processor = load_musicgen_model()
                inputs = mg_processor(
                    text=[st.session_state["final_script"]],
                    padding=True,
                    return_tensors="pt",
                )
                audio_values = mg_model.generate(**inputs, max_new_tokens=audio_length)
                sr = mg_model.config.audio_encoder.sampling_rate
                output_file = "radio_jingle.wav"
                audio_data = audio_values[0, 0].cpu().numpy()

                # Peak-normalise to the int16 range. The array method is a
                # single vectorised reduction (the builtin max() iterates
                # sample by sample), and the guard avoids dividing by zero
                # when the generated audio is completely silent.
                peak = abs(audio_data).max()
                if peak > 0:
                    audio_data = audio_data / peak
                normalized_audio = (audio_data * 32767).astype("int16")

                wav.write(output_file, rate=sr, data=normalized_audio)
                st.success("Audio generated! Play it below:")
                st.audio(output_file)
            except Exception as e:
                # UI boundary: surface any failure to the user.
                st.error(f"MusicGen error: {e}")
# --------------------------------------------------------------------- | |
# 10) FOOTER | |
# --------------------------------------------------------------------- | |
# Footer: a divider followed by a centred credit line styled by the
# "footer-note" CSS class (raw HTML, hence unsafe_allow_html).
st.markdown("---")
st.markdown(
    """
<div class="footer-note">
© 2025 AI Radio Imaging – Built with Hugging Face & Streamlit
</div>
""",
    unsafe_allow_html=True,
)