Bils committed on
Commit
613bd9e
·
verified ·
1 Parent(s): c909f5f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -15
app.py CHANGED
@@ -86,21 +86,25 @@ lottie_animation = load_lottie_url(LOTTIE_URL)
86
  # ---------------------------------------------------------------------
87
  @st.cache_resource
88
  def load_llama_pipeline(model_id: str, device: str, token: str):
89
- tokenizer = AutoTokenizer.from_pretrained(model_id, use_auth_token=token)
90
- model = AutoModelForCausalLM.from_pretrained(
91
- model_id,
92
- use_auth_token=token,
93
- torch_dtype=torch.float16 if device == "auto" else torch.float32,
94
- device_map=device,
95
- low_cpu_mem_usage=True
96
- )
97
- text_gen_pipeline = pipeline(
98
- "text-generation",
99
- model=model,
100
- tokenizer=tokenizer,
101
- device_map=device
102
- )
103
- return text_gen_pipeline
 
 
 
 
104
 
105
  # ---------------------------------------------------------------------
106
  # 5) GENERATE RADIO SCRIPT
 
86
  # ---------------------------------------------------------------------
87
  @st.cache_resource
88
  def load_llama_pipeline(model_id: str, device: str, token: str):
89
+ try:
90
+ tokenizer = AutoTokenizer.from_pretrained(model_id, use_auth_token=token)
91
+ model = AutoModelForCausalLM.from_pretrained(
92
+ model_id,
93
+ use_auth_token=token,
94
+ torch_dtype=torch.float16 if device == "auto" else torch.float32,
95
+ device_map=device,
96
+ low_cpu_mem_usage=True
97
+ )
98
+ text_gen_pipeline = pipeline(
99
+ "text-generation",
100
+ model=model,
101
+ tokenizer=tokenizer,
102
+ device_map=device
103
+ )
104
+ return text_gen_pipeline
105
+ except Exception as e:
106
+ st.error(f"Error loading Llama model: {e}")
107
+ raise
108
 
109
  # ---------------------------------------------------------------------
110
  # 5) GENERATE RADIO SCRIPT