Bils committed on
Commit
b6700b8
·
verified ·
1 Parent(s): 961b217

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -15
app.py CHANGED
@@ -13,13 +13,23 @@ import tempfile
13
  from dotenv import load_dotenv
14
  import spaces
15
 
 
16
  load_dotenv()
17
  hf_token = os.getenv("HF_TOKEN")
18
 
 
 
 
 
 
 
 
 
 
19
  # ---------------------------------------------------------------------
20
  # Load Llama 3 Pipeline with Zero GPU (Encapsulated)
21
  # ---------------------------------------------------------------------
22
- @spaces.GPU(duration=300)
23
  def generate_script(user_prompt: str, model_id: str, token: str):
24
  try:
25
  tokenizer = AutoTokenizer.from_pretrained(model_id, use_auth_token=token)
@@ -43,7 +53,6 @@ def generate_script(user_prompt: str, model_id: str, token: str):
43
  except Exception as e:
44
  return f"Error generating script: {e}"
45
 
46
-
47
  # ---------------------------------------------------------------------
48
  # Load MusicGen Model (Encapsulated)
49
  # ---------------------------------------------------------------------
@@ -56,7 +65,7 @@ def generate_audio(prompt: str, audio_length: int):
56
  musicgen_model.to("cuda")
57
  inputs = musicgen_processor(text=[prompt], padding=True, return_tensors="pt")
58
  outputs = musicgen_model.generate(**inputs, max_new_tokens=audio_length)
59
- musicgen_model.to("cpu")
60
 
61
  sr = musicgen_model.config.audio_encoder.sampling_rate
62
  audio_data = outputs[0, 0].cpu().numpy()
@@ -69,18 +78,15 @@ def generate_audio(prompt: str, audio_length: int):
69
  except Exception as e:
70
  return f"Error generating audio: {e}"
71
 
72
-
73
  # ---------------------------------------------------------------------
74
  # Gradio Interface Functions
75
  # ---------------------------------------------------------------------
76
  def interface_generate_script(user_prompt, llama_model_id):
77
  return generate_script(user_prompt, llama_model_id, hf_token)
78
 
79
-
80
  def interface_generate_audio(script, audio_length):
81
  return generate_audio(script, audio_length)
82
 
83
-
84
  # ---------------------------------------------------------------------
85
  # Interface
86
  # ---------------------------------------------------------------------
@@ -90,7 +96,7 @@ with gr.Blocks() as demo:
90
  # πŸŽ™οΈ AI-Powered Radio Imaging Studio πŸš€
91
  ### Create stunning **radio promos** with **Llama 3** and **MusicGen**
92
  🔥 **Zero GPU** integration for efficiency and ease!
93
- ❤️ A huge thanks to the **Hugging Face community** for making this possible.
94
  """)
95
 
96
  # Script Generation Section
@@ -126,19 +132,18 @@ with gr.Blocks() as demo:
126
  value=512,
127
  info="Select the desired audio token length."
128
  )
129
- generate_audio_button = gr.Button("Generate Audio 🎶")
130
- audio_output = gr.Audio(
131
- label="🎶 Generated Audio File",
132
- type="filepath",
133
- interactive=False
134
- )
135
 
136
  # Footer
137
  gr.Markdown("""
138
  <br><hr>
139
  <p style="text-align: center; font-size: 0.9em;">
140
- Created with ❤️ by <a href="https://bilsimaging.com" target="_blank">bilsimaging.com</a>
141
- Special thanks to the <strong>Hugging Face community</strong> for their incredible support and tools!
142
  </p>
143
  """, elem_id="footer")
144
 
 
13
  from dotenv import load_dotenv
14
  import spaces
15
 
16
+ # Load environment variables
17
  load_dotenv()
18
  hf_token = os.getenv("HF_TOKEN")
19
 
20
+ # Check and enable Xformers for memory-efficient attention
21
+ if torch.cuda.is_available():
22
+ try:
23
+ from xformers.ops import memory_efficient_attention
24
+ os.environ["XFORMERS_ATTENTION"] = "1"
25
+ print("Xformers is enabled for memory-efficient attention.")
26
+ except ImportError:
27
+ print("Xformers is not installed or could not be imported.")
28
+
29
  # ---------------------------------------------------------------------
30
  # Load Llama 3 Pipeline with Zero GPU (Encapsulated)
31
  # ---------------------------------------------------------------------
32
+ @spaces.GPU(duration=300) # GPU allocation for 300 seconds
33
  def generate_script(user_prompt: str, model_id: str, token: str):
34
  try:
35
  tokenizer = AutoTokenizer.from_pretrained(model_id, use_auth_token=token)
 
53
  except Exception as e:
54
  return f"Error generating script: {e}"
55
 
 
56
  # ---------------------------------------------------------------------
57
  # Load MusicGen Model (Encapsulated)
58
  # ---------------------------------------------------------------------
 
65
  musicgen_model.to("cuda")
66
  inputs = musicgen_processor(text=[prompt], padding=True, return_tensors="pt")
67
  outputs = musicgen_model.generate(**inputs, max_new_tokens=audio_length)
68
+ musicgen_model.to("cpu") # Return the model to CPU
69
 
70
  sr = musicgen_model.config.audio_encoder.sampling_rate
71
  audio_data = outputs[0, 0].cpu().numpy()
 
78
  except Exception as e:
79
  return f"Error generating audio: {e}"
80
 
 
81
  # ---------------------------------------------------------------------
82
  # Gradio Interface Functions
83
  # ---------------------------------------------------------------------
84
  def interface_generate_script(user_prompt, llama_model_id):
85
  return generate_script(user_prompt, llama_model_id, hf_token)
86
 
 
87
  def interface_generate_audio(script, audio_length):
88
  return generate_audio(script, audio_length)
89
 
 
90
  # ---------------------------------------------------------------------
91
  # Interface
92
  # ---------------------------------------------------------------------
 
96
  # πŸŽ™οΈ AI-Powered Radio Imaging Studio πŸš€
97
  ### Create stunning **radio promos** with **Llama 3** and **MusicGen**
98
  🔥 **Zero GPU** integration for efficiency and ease!
99
+ 🙌 Thanks to the Hugging Face community for supporting this innovation.
100
  """)
101
 
102
  # Script Generation Section
 
132
  value=512,
133
  info="Select the desired audio token length."
134
  )
135
+ generate_audio_button = gr.Button("Generate Audio 🎶")
136
+ audio_output = gr.Audio(
137
+ label="🎶 Generated Audio File",
138
+ type="filepath",
139
+ interactive=False
140
+ )
141
 
142
  # Footer
143
  gr.Markdown("""
144
  <br><hr>
145
  <p style="text-align: center; font-size: 0.9em;">
146
+ Created with ❤️ by <a href="https://bilsimaging.com" target="_blank">bilsimaging.com</a>
 
147
  </p>
148
  """, elem_id="footer")
149