Bils committed on
Commit 8b6a33e · verified · 1 Parent(s): e7b189b

Update app.py

Files changed (1)
app.py +68 -58
app.py CHANGED
@@ -11,13 +11,13 @@ from transformers import (
 from scipy.io.wavfile import write
 import tempfile
 from dotenv import load_dotenv
-import spaces
+import spaces  # Assumes Hugging Face Spaces library supports `@spaces.GPU`
 
 # Load environment variables (e.g., Hugging Face token)
 load_dotenv()
 hf_token = os.getenv("HF_TOKEN")
 
-# Globals for Lazy Loading
+# Globals for lazy loading
 llama_pipeline = None
 musicgen_model = None
 musicgen_processor = None
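
The `spaces` import enables Hugging Face ZeroGPU, where a GPU is attached only while a decorated function runs. A minimal sketch of the pattern in isolation (assuming the `spaces` package that ships on ZeroGPU Spaces; `describe_device` is a hypothetical example function, not part of app.py):

import spaces
import torch

@spaces.GPU(duration=300)  # GPU is allocated only for the lifetime of this call
def describe_device() -> str:
    # On ZeroGPU hardware, CUDA is visible only inside @spaces.GPU-decorated functions
    return torch.cuda.get_device_name(0) if torch.cuda.is_available() else "cpu"
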
@@ -25,12 +25,14 @@ musicgen_processor = None
 # ---------------------------------------------------------------------
 # Load Llama 3 Model with Zero GPU (Lazy Loading)
 # ---------------------------------------------------------------------
-@spaces.GPU(duration=120)
+@spaces.GPU(duration=300)  # Increased duration to 300 seconds
 def load_llama_pipeline_zero_gpu(model_id: str, token: str):
     global llama_pipeline
     if llama_pipeline is None:
         try:
+            print("Starting model loading...")
             tokenizer = AutoTokenizer.from_pretrained(model_id, use_auth_token=token)
+            print("Tokenizer loaded.")
             model = AutoModelForCausalLM.from_pretrained(
                 model_id,
                 use_auth_token=token,
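
A side note on the loading calls above: recent transformers releases deprecate `use_auth_token` in favor of `token`. A hedged equivalent, assuming transformers >= 4.34:

tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    token=token,  # replaces the deprecated use_auth_token argument
    device_map="auto",
    trust_remote_code=True,
)
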
@@ -38,56 +40,63 @@ def load_llama_pipeline_zero_gpu(model_id: str, token: str):
                 device_map="auto",  # Automatically handles GPU allocation
                 trust_remote_code=True
             )
+            print("Model loaded. Initializing pipeline...")
             llama_pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer)
+            print("Pipeline initialized successfully.")
         except Exception as e:
-            return f"Error loading Llama pipeline: {e}"
+            print(f"Error loading Llama pipeline: {e}")
+            return str(e)
     return llama_pipeline
 
-# ---------------------------------------------------------------------
-# Load MusicGen Model (Lazy Loading)
-# ---------------------------------------------------------------------
-@spaces.GPU(duration=120)
-def load_musicgen_model():
-    global musicgen_model, musicgen_processor
-    if musicgen_model is None or musicgen_processor is None:
-        try:
-            musicgen_model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
-            musicgen_processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
-        except Exception as e:
-            return None, f"Error loading MusicGen model: {e}"
-    return musicgen_model, musicgen_processor
-
 # ---------------------------------------------------------------------
 # Generate Radio Script
 # ---------------------------------------------------------------------
-def generate_script(user_input: str, llama_pipeline):
+def generate_script(user_input: str, pipeline_llama):
     try:
         system_prompt = (
             "You are a top-tier radio imaging producer using Llama 3. "
             "Take the user's concept and craft a short, creative promo script."
         )
         combined_prompt = f"{system_prompt}\nUser concept: {user_input}\nRefined script:"
-        result = llama_pipeline(combined_prompt, max_new_tokens=200, do_sample=True, temperature=0.9)
+        result = pipeline_llama(combined_prompt, max_new_tokens=200, do_sample=True, temperature=0.9)
         return result[0]['generated_text'].split("Refined script:")[-1].strip()
     except Exception as e:
         return f"Error generating script: {e}"
 
+# ---------------------------------------------------------------------
+# Load MusicGen Model (Lazy Loading)
+# ---------------------------------------------------------------------
+@spaces.GPU(duration=300)
+def load_musicgen_model():
+    global musicgen_model, musicgen_processor
+    if musicgen_model is None or musicgen_processor is None:
+        try:
+            print("Loading MusicGen model...")
+            musicgen_model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
+            musicgen_processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
+            print("MusicGen model loaded successfully.")
+        except Exception as e:
+            print(f"Error loading MusicGen model: {e}")
+            return None, str(e)
+    return musicgen_model, musicgen_processor
+
 # ---------------------------------------------------------------------
 # Generate Audio
 # ---------------------------------------------------------------------
-@spaces.GPU(duration=120)
+@spaces.GPU(duration=300)
 def generate_audio(prompt: str, audio_length: int):
-    mg_model, mg_processor = load_musicgen_model()
-    if mg_model is None or isinstance(mg_processor, str):
-        return mg_processor
-
+    global musicgen_model, musicgen_processor
+    if musicgen_model is None or musicgen_processor is None:
+        musicgen_model, musicgen_processor = load_musicgen_model()
+    if isinstance(musicgen_model, str):
+        return musicgen_model
     try:
-        mg_model.to("cuda")  # Move the model to GPU
-        inputs = mg_processor(text=[prompt], padding=True, return_tensors="pt")
-        outputs = mg_model.generate(**inputs, max_new_tokens=audio_length)
-        mg_model.to("cpu")  # Return the model to CPU
+        musicgen_model.to("cuda")  # Move the model to GPU
+        inputs = musicgen_processor(text=[prompt], padding=True, return_tensors="pt")
+        outputs = musicgen_model.generate(**inputs, max_new_tokens=audio_length)
+        musicgen_model.to("cpu")  # Return the model to CPU
 
-        sr = mg_model.config.audio_encoder.sampling_rate
+        sr = musicgen_model.config.audio_encoder.sampling_rate
         audio_data = outputs[0, 0].cpu().numpy()
         normalized_audio = (audio_data / max(abs(audio_data)) * 32767).astype("int16")
 
@@ -101,16 +110,19 @@ def generate_audio(prompt: str, audio_length: int):
 # Gradio Interface
 # ---------------------------------------------------------------------
 def radio_imaging_script(user_prompt, llama_model_id):
-    llama_pipeline = load_llama_pipeline_zero_gpu(llama_model_id, hf_token)
-    if isinstance(llama_pipeline, str):
-        return llama_pipeline
+    # Load Llama 3 Pipeline with Zero GPU
+    pipeline_llama = load_llama_pipeline_zero_gpu(llama_model_id, hf_token)
+    if isinstance(pipeline_llama, str):
+        return pipeline_llama
 
     # Generate Script
-    script = generate_script(user_prompt, llama_pipeline)
+    script = generate_script(user_prompt, pipeline_llama)
     return script
 
 def radio_imaging_audio(script, audio_length):
-    return generate_audio(script, audio_length)
+    # Generate Audio
+    audio_data = generate_audio(script, audio_length)
+    return audio_data
 
 # ---------------------------------------------------------------------
 # Interface
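
Since load failures are returned as plain strings, a failed model load is rendered in the script textbox as if it were output. An alternative sketch using Gradio's built-in error reporting (gr.Error is part of the Gradio API; the handler otherwise mirrors the one above):

def radio_imaging_script(user_prompt, llama_model_id):
    pipeline_llama = load_llama_pipeline_zero_gpu(llama_model_id, hf_token)
    if isinstance(pipeline_llama, str):
        # Surface the failure as a UI error instead of fake script output
        raise gr.Error(f"Model loading failed: {pipeline_llama}")
    return generate_script(user_prompt, pipeline_llama)
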
@@ -120,31 +132,29 @@ with gr.Blocks() as demo:
 
     # Script Generation Section
     with gr.Row():
-        with gr.Column():
-            gr.Markdown("## Step 1: Generate the Promo Script")
-            user_prompt = gr.Textbox(label="Enter your promo idea", placeholder="E.g., A 15-second hype jingle for a morning talk show.")
-            llama_model_id = gr.Textbox(label="Llama 3 Model ID", value="meta-llama/Meta-Llama-3-70B")
-            generate_script_button = gr.Button("Generate Promo Script")
-            script_output = gr.Textbox(label="Generated Script", interactive=False)
-
-            generate_script_button.click(
-                fn=radio_imaging_script,
-                inputs=[user_prompt, llama_model_id],
-                outputs=script_output
-            )
+        gr.Markdown("## Step 1: Generate the Promo Script")
+        user_prompt = gr.Textbox(label="Enter your promo idea", placeholder="E.g., A 15-second hype jingle for a morning talk show.")
+        llama_model_id = gr.Textbox(label="Llama 3 Model ID", value="meta-llama/Meta-Llama-3-70B")
+        generate_script_button = gr.Button("Generate Promo Script")
+        script_output = gr.Textbox(label="Generated Script", interactive=False)
+
+        generate_script_button.click(
+            fn=radio_imaging_script,
+            inputs=[user_prompt, llama_model_id],
+            outputs=script_output
+        )
 
     # Audio Generation Section
     with gr.Row():
-        with gr.Column():
-            gr.Markdown("## Step 2: Generate the Sound")
-            audio_length = gr.Slider(label="Audio Length (tokens)", minimum=128, maximum=1024, step=64, value=512)
-            generate_audio_button = gr.Button("Generate Sound from Script")
-            audio_output = gr.Audio(label="Generated Audio", type="filepath")
-
-            generate_audio_button.click(
-                fn=radio_imaging_audio,
-                inputs=[script_output, audio_length],
-                outputs=audio_output
-            )
+        gr.Markdown("## Step 2: Generate the Sound")
+        audio_length = gr.Slider(label="Audio Length (tokens)", minimum=128, maximum=1024, step=64, value=512)
+        generate_audio_button = gr.Button("Generate Sound from Script")
+        audio_output = gr.Audio(label="Generated Audio", type="filepath")
+
+        generate_audio_button.click(
+            fn=radio_imaging_audio,
+            inputs=[script_output, audio_length],
+            outputs=audio_output
+        )
 
 demo.launch(debug=True)
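
One layout consequence of dropping the gr.Column() wrappers: children of gr.Row() are laid out horizontally, so each step's widgets now sit side by side instead of stacked. If vertical stacking was intended, the column can be kept without changing anything else, as in this sketch:

with gr.Row():
    with gr.Column():  # stacks the Step 1 widgets vertically inside the row
        gr.Markdown("## Step 1: Generate the Promo Script")
        user_prompt = gr.Textbox(label="Enter your promo idea")
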
 