Bils committed
Commit e7b189b · verified · 1 Parent(s): a8c9cb5

Update app.py

Files changed (1)
  app.py +48 -42
app.py CHANGED
@@ -11,62 +11,76 @@ from transformers import (
 from scipy.io.wavfile import write
 import tempfile
 from dotenv import load_dotenv
-import spaces  # Assumes Hugging Face Spaces library supports `@spaces.GPU`
+import spaces

 # Load environment variables (e.g., Hugging Face token)
 load_dotenv()
 hf_token = os.getenv("HF_TOKEN")

+# Globals for Lazy Loading
+llama_pipeline = None
+musicgen_model = None
+musicgen_processor = None
+
 # ---------------------------------------------------------------------
-# Load Llama 3 Model with Zero GPU
+# Load Llama 3 Model with Zero GPU (Lazy Loading)
 # ---------------------------------------------------------------------
 @spaces.GPU(duration=120)
 def load_llama_pipeline_zero_gpu(model_id: str, token: str):
-    try:
-        tokenizer = AutoTokenizer.from_pretrained(model_id, use_auth_token=token)
-        model = AutoModelForCausalLM.from_pretrained(
-            model_id,
-            use_auth_token=token,
-            torch_dtype=torch.float16,
-            device_map="auto",  # Automatically handles GPU allocation
-            trust_remote_code=True
-        )
-        return pipeline("text-generation", model=model, tokenizer=tokenizer)
-    except Exception as e:
-        return str(e)
+    global llama_pipeline
+    if llama_pipeline is None:
+        try:
+            tokenizer = AutoTokenizer.from_pretrained(model_id, use_auth_token=token)
+            model = AutoModelForCausalLM.from_pretrained(
+                model_id,
+                use_auth_token=token,
+                torch_dtype=torch.float16,
+                device_map="auto",  # Automatically handles GPU allocation
+                trust_remote_code=True
+            )
+            llama_pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer)
+        except Exception as e:
+            return f"Error loading Llama pipeline: {e}"
+    return llama_pipeline
+
+# ---------------------------------------------------------------------
+# Load MusicGen Model (Lazy Loading)
+# ---------------------------------------------------------------------
+@spaces.GPU(duration=120)
+def load_musicgen_model():
+    global musicgen_model, musicgen_processor
+    if musicgen_model is None or musicgen_processor is None:
+        try:
+            musicgen_model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
+            musicgen_processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
+        except Exception as e:
+            return None, f"Error loading MusicGen model: {e}"
+    return musicgen_model, musicgen_processor

 # ---------------------------------------------------------------------
 # Generate Radio Script
 # ---------------------------------------------------------------------
-def generate_script(user_input: str, pipeline_llama):
+def generate_script(user_input: str, llama_pipeline):
     try:
         system_prompt = (
             "You are a top-tier radio imaging producer using Llama 3. "
             "Take the user's concept and craft a short, creative promo script."
         )
         combined_prompt = f"{system_prompt}\nUser concept: {user_input}\nRefined script:"
-        result = pipeline_llama(combined_prompt, max_new_tokens=200, do_sample=True, temperature=0.9)
+        result = llama_pipeline(combined_prompt, max_new_tokens=200, do_sample=True, temperature=0.9)
         return result[0]['generated_text'].split("Refined script:")[-1].strip()
     except Exception as e:
         return f"Error generating script: {e}"

-# ---------------------------------------------------------------------
-# Load MusicGen Model
-# ---------------------------------------------------------------------
-@spaces.GPU(duration=120)
-def load_musicgen_model():
-    try:
-        model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
-        processor = AutoProcessor.from_pretrained("facebook/musicgen-small")
-        return model, processor
-    except Exception as e:
-        return None, str(e)
-
 # ---------------------------------------------------------------------
 # Generate Audio
 # ---------------------------------------------------------------------
 @spaces.GPU(duration=120)
-def generate_audio(prompt: str, audio_length: int, mg_model, mg_processor):
+def generate_audio(prompt: str, audio_length: int):
+    mg_model, mg_processor = load_musicgen_model()
+    if mg_model is None or isinstance(mg_processor, str):
+        return mg_processor
+
     try:
         mg_model.to("cuda")  # Move the model to GPU
         inputs = mg_processor(text=[prompt], padding=True, return_tensors="pt")
@@ -87,24 +101,16 @@ def generate_audio(prompt: str, audio_length: int, mg_model, mg_processor):
 # Gradio Interface
 # ---------------------------------------------------------------------
 def radio_imaging_script(user_prompt, llama_model_id):
-    # Load Llama 3 Pipeline with Zero GPU
-    pipeline_llama = load_llama_pipeline_zero_gpu(llama_model_id, hf_token)
-    if isinstance(pipeline_llama, str):
-        return pipeline_llama
+    llama_pipeline = load_llama_pipeline_zero_gpu(llama_model_id, hf_token)
+    if isinstance(llama_pipeline, str):
+        return llama_pipeline

     # Generate Script
-    script = generate_script(user_prompt, pipeline_llama)
+    script = generate_script(user_prompt, llama_pipeline)
     return script

 def radio_imaging_audio(script, audio_length):
-    # Load MusicGen
-    mg_model, mg_processor = load_musicgen_model()
-    if isinstance(mg_processor, str):
-        return mg_processor
-
-    # Generate Audio
-    audio_data = generate_audio(script, audio_length, mg_model, mg_processor)
-    return audio_data
+    return generate_audio(script, audio_length)

 # ---------------------------------------------------------------------
 # Interface
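
The core change in this commit is the lazy-loading pattern: model handles are cached in module-level globals so that repeated Gradio callbacks reuse them instead of reloading weights on every request. A minimal, standalone sketch of that pattern, with build_pipeline as a hypothetical stand-in for the real Llama/MusicGen loaders:

    # Minimal sketch of the lazy-loading pattern adopted in this commit.
    # build_pipeline is a hypothetical stand-in for the actual
    # AutoModelForCausalLM / MusicGen weight-loading code.
    _cached_pipeline = None  # module-level cache, populated on first use

    def build_pipeline():
        # Pretend this is the expensive weight-loading step.
        return {"model": "loaded"}

    def get_pipeline():
        global _cached_pipeline
        if _cached_pipeline is None:      # only the first call pays the cost
            _cached_pipeline = build_pipeline()
        return _cached_pipeline          # later calls return the cached object

One caveat worth verifying on ZeroGPU hardware: functions decorated with @spaces.GPU can execute in a separate GPU-attached process, so whether a global mutated inside the decorated function actually persists across calls depends on how the Space isolates those processes.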
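
The diff cuts off inside generate_audio, but the surviving imports (scipy.io.wavfile.write, tempfile) suggest how the function presumably finishes: decode MusicGen's generated tokens into a waveform and write it to a temporary .wav file. The following is a hedged sketch of that tail end, not the committed code; render_wav is a hypothetical name, and the ~50-tokens-per-second rate and sampling-rate lookup follow the standard transformers MusicGen API:

    import tempfile

    import torch
    from scipy.io.wavfile import write
    from transformers import AutoProcessor, MusicgenForConditionalGeneration

    model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small")
    processor = AutoProcessor.from_pretrained("facebook/musicgen-small")

    def render_wav(prompt: str, audio_length: int) -> str:
        """Generate roughly audio_length seconds of audio; return a wav path."""
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model.to(device)
        inputs = processor(text=[prompt], padding=True, return_tensors="pt").to(device)
        # MusicGen emits roughly 50 audio tokens per second of output.
        tokens = model.generate(**inputs, max_new_tokens=int(audio_length * 50))
        rate = model.config.audio_encoder.sampling_rate  # 32 kHz for musicgen-small
        waveform = tokens[0, 0].cpu().numpy()  # (batch, channel, samples) -> mono float32
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as f:
            write(f.name, rate, waveform)  # scipy accepts float32 in [-1, 1]
            return f.name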