Bils committed (verified)
Commit dfa5d3e · 1 parent: b5ad742

Update app.py

Files changed (1)
  app.py  +66 -5
app.py CHANGED
@@ -11,11 +11,15 @@ from transformers import (
 from scipy.io.wavfile import write
 import tempfile
 from dotenv import load_dotenv
-import spaces
+import spaces  # Assumes Hugging Face Spaces library supports `@spaces.GPU`
 
+# Load environment variables (e.g., Hugging Face token)
 load_dotenv()
 hf_token = os.getenv("HF_TOKEN")
 
+# ---------------------------------------------------------------------
+# Load Llama 3 Model with Zero GPU
+# ---------------------------------------------------------------------
 @spaces.GPU(duration=120)
 def load_llama_pipeline_zero_gpu(model_id: str, token: str):
     try:
@@ -24,20 +28,50 @@ def load_llama_pipeline_zero_gpu(model_id: str, token: str):
             model_id,
             use_auth_token=token,
             torch_dtype=torch.float16,
-            device_map="auto",
+            device_map="auto",  # Automatically handles GPU allocation
             trust_remote_code=True
         )
         return pipeline("text-generation", model=model, tokenizer=tokenizer)
     except Exception as e:
         return str(e)
 
+# ---------------------------------------------------------------------
+# Generate Radio Script
+# ---------------------------------------------------------------------
+def generate_script(user_input: str, pipeline_llama):
+    try:
+        system_prompt = (
+            "You are a top-tier radio imaging producer using Llama 3. "
+            "Take the user's concept and craft a short, creative promo script."
+        )
+        combined_prompt = f"{system_prompt}\nUser concept: {user_input}\nRefined script:"
+        result = pipeline_llama(combined_prompt, max_new_tokens=200, do_sample=True, temperature=0.9)
+        return result[0]['generated_text'].split("Refined script:")[-1].strip()
+    except Exception as e:
+        return f"Error generating script: {e}"
+
+# ---------------------------------------------------------------------
+# Load MusicGen Model
+# ---------------------------------------------------------------------
+@spaces.GPU(duration=120)
+def load_musicgen_model():
+    try:
+        model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
+        processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
+        return model, processor
+    except Exception as e:
+        return None, str(e)
+
+# ---------------------------------------------------------------------
+# Generate Audio
+# ---------------------------------------------------------------------
 @spaces.GPU(duration=120)
 def generate_audio(prompt: str, audio_length: int, mg_model, mg_processor):
     try:
-        mg_model.to("cuda")
+        mg_model.to("cuda")  # Move the model to GPU
         inputs = mg_processor(text=[prompt], padding=True, return_tensors="pt")
         outputs = mg_model.generate(**inputs, max_new_tokens=audio_length)
-        mg_model.to("cpu")
+        mg_model.to("cpu")  # Return the model to CPU
 
         sr = mg_model.config.audio_encoder.sampling_rate
         audio_data = outputs[0, 0].cpu().numpy()
@@ -49,6 +83,33 @@ def generate_audio(prompt: str, audio_length: int, mg_model, mg_processor):
     except Exception as e:
         return f"Error generating audio: {e}"
 
+# ---------------------------------------------------------------------
+# Gradio Interface
+# ---------------------------------------------------------------------
+def radio_imaging_app(user_prompt, llama_model_id, hf_token, audio_length):
+    # Load Llama 3 Pipeline with Zero GPU
+    pipeline_llama = load_llama_pipeline_zero_gpu(llama_model_id, hf_token)
+    if isinstance(pipeline_llama, str):
+        return pipeline_llama, None
+
+    # Generate Script
+    script = generate_script(user_prompt, pipeline_llama)
+
+    # Load MusicGen
+    mg_model, mg_processor = load_musicgen_model()
+    if isinstance(mg_processor, str):
+        return script, mg_processor
+
+    # Generate Audio
+    audio_data = generate_audio(script, audio_length, mg_model, mg_processor)
+    if isinstance(audio_data, str):
+        return script, audio_data
+
+    return script, audio_data
+
+# ---------------------------------------------------------------------
+# Interface
+# ---------------------------------------------------------------------
 with gr.Blocks() as demo:
     gr.Markdown("# 🎧 AI Radio Imaging with Llama 3 + MusicGen (Zero GPU)")
     user_prompt = gr.Textbox(label="Enter your promo idea", placeholder="E.g., A 15-second hype jingle for a morning talk show.")
@@ -61,7 +122,7 @@ with gr.Blocks() as demo:
     audio_output = gr.Audio(label="Generated Audio", type="filepath")
 
     generate_button.click(
-        fn=lambda prompt, model_id, token, length: (prompt, None),  # Simplify for demo
+        fn=radio_imaging_app,
         inputs=[user_prompt, llama_model_id, hf_token, audio_length],
         outputs=[script_output, audio_output]
     )
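
Note on device handling in generate_audio: the committed code moves mg_model to "cuda" but leaves the processor's tensors on the CPU, so generate will typically fail with a device-mismatch error once a GPU is attached. Below is a minimal sketch of the usual fix, assuming the processor returns a dict-like batch of PyTorch tensors; the helper name generate_audio_on_device is illustrative, not part of the commit.

import torch

def generate_audio_on_device(prompt: str, audio_length: int, mg_model, mg_processor):
    # Run on GPU when one is attached; fall back to CPU otherwise.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    mg_model.to(device)
    inputs = mg_processor(text=[prompt], padding=True, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}  # keep tensors with the model
    outputs = mg_model.generate(**inputs, max_new_tokens=audio_length)
    mg_model.to("cpu")  # release the GPU between calls, as the commit does
    sr = mg_model.config.audio_encoder.sampling_rate
    return sr, outputs[0, 0].cpu().numpy()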
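
Separately, the hunks above elide the lines that turn audio_data and sr into the file path expected by gr.Audio(type="filepath"). Given the tempfile and scipy.io.wavfile.write imports, a typical implementation would look like the sketch below; this is hypothetical, not the committed code.

import tempfile
from scipy.io.wavfile import write

def save_wav_to_tempfile(audio_data, sr: int) -> str:
    # Hypothetical save step: MusicGen emits float32 samples in [-1.0, 1.0],
    # which scipy's write() accepts as (filename, sample_rate, samples).
    tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
    write(tmp.name, sr, audio_data)
    return tmp.name  # hand this path to the gr.Audio output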