Bils committed · Commit f0b5707 · verified · 1 Parent(s): 3257580

Update app.py

Files changed (1): app.py (+9, -9)
app.py CHANGED
@@ -11,19 +11,19 @@ from transformers import (
 import scipy.io.wavfile as wav
 
 # ---------------------------------------------------------------------
-# Load Llama 3 Model
+# Load Llama 3 Model with Zero GPU
 # ---------------------------------------------------------------------
-def load_llama_pipeline(model_id: str, token: str, device: str = "cpu"):
+def load_llama_pipeline_zero_gpu(model_id: str, token: str):
     try:
         tokenizer = AutoTokenizer.from_pretrained(model_id, use_auth_token=token)
         model = AutoModelForCausalLM.from_pretrained(
             model_id,
             use_auth_token=token,
-            torch_dtype=torch.float16 if device == "cuda" else torch.float32,
-            device_map="auto" if device == "cuda" else None,
-            low_cpu_mem_usage=True
+            torch_dtype=torch.float16,
+            device_map="auto",  # Use device map to offload computations
+            trust_remote_code=True  # Enables execution of remote code for Zero GPU
         )
-        return pipeline("text-generation", model=model, tokenizer=tokenizer, device=0 if device == "cuda" else -1)
+        return pipeline("text-generation", model=model, tokenizer=tokenizer)
     except Exception as e:
         return str(e)
 
@@ -73,8 +73,8 @@ def generate_audio(prompt: str, audio_length: int, mg_model, mg_processor):
 # Gradio Interface
 # ---------------------------------------------------------------------
 def radio_imaging_app(user_prompt, llama_model_id, hf_token, audio_length):
-    # Load Llama 3 Pipeline
-    pipeline_llama = load_llama_pipeline(llama_model_id, hf_token, device="cuda" if torch.cuda.is_available() else "cpu")
+    # Load Llama 3 Pipeline with Zero GPU
+    pipeline_llama = load_llama_pipeline_zero_gpu(llama_model_id, hf_token)
     if isinstance(pipeline_llama, str):
         return pipeline_llama, None
 
@@ -97,7 +97,7 @@ def radio_imaging_app(user_prompt, llama_model_id, hf_token, audio_length):
 # Interface
 # ---------------------------------------------------------------------
 with gr.Blocks() as demo:
-    gr.Markdown("# 🎧 AI Radio Imaging with Llama 3 + MusicGen")
+    gr.Markdown("# 🎧 AI Radio Imaging with Llama 3 + MusicGen (Zero GPU)")
     with gr.Row():
         user_prompt = gr.Textbox(label="Enter your promo idea", placeholder="E.g., A 15-second hype jingle for a morning talk show, fun and energetic.")
         llama_model_id = gr.Textbox(label="Llama 3 Model ID", value="meta-llama/Meta-Llama-3-70B")
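
For a quick sanity check of the renamed loader outside the Space, a minimal smoke-test sketch follows. It is illustrative and not part of the commit: the model ID mirrors the textbox default above, and HF_TOKEN stands in for any Hugging Face access token with access to the gated Llama 3 weights.

# Hypothetical smoke test for load_llama_pipeline_zero_gpu (not part of this commit)
import os
from app import load_llama_pipeline_zero_gpu  # assumes app.py is on the import path

pipe = load_llama_pipeline_zero_gpu(
    "meta-llama/Meta-Llama-3-70B",  # default value of the Space's model-ID textbox
    os.environ["HF_TOKEN"],         # placeholder for a valid Hugging Face token
)
if isinstance(pipe, str):
    print(f"Load failed: {pipe}")   # the loader returns the error message as a string
else:
    result = pipe("Write a 15-second radio promo script.", max_new_tokens=64)
    print(result[0]["generated_text"])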
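As context for the "Zero GPU" naming: on Hugging Face ZeroGPU Spaces, GPU time is normally requested by wrapping the GPU-bound call with the spaces.GPU decorator, while trust_remote_code=True only permits a model repo's custom code to run. A sketch of that pattern, with the wrapper name run_llama being hypothetical:

import spaces  # Hugging Face Spaces helper providing the ZeroGPU decorator

@spaces.GPU  # a GPU is attached only while this function runs on a ZeroGPU Space
def run_llama(pipe, prompt: str) -> str:
    # pipe is the text-generation pipeline returned by load_llama_pipeline_zero_gpu
    return pipe(prompt, max_new_tokens=64)[0]["generated_text"]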