akashmadisetty committed
Commit 45c882e · 1 Parent(s): c1d34f4
Files changed (1)
  1. app.py +32 -30
app.py CHANGED
@@ -44,15 +44,12 @@ def load_model(hf_token):
             token=hf_token
         )
 
-        # Load model with safe configuration
+        # Load model with minimal configuration to avoid errors
         global_model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            device_map="auto",
-            token=hf_token,
-            use_cache=True,
-            low_cpu_mem_usage=True,
-            attn_implementation="flash_attention_2" if torch.cuda.is_available() else "eager"
+            token=hf_token
         )
 
         model_loaded = True
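
Note: the removed `attn_implementation` switch traded speed for fragility — requesting `flash_attention_2` raises at load time when the flash-attn package is not installed, even on CUDA machines. If the speedup is still wanted, a runtime fallback avoids the hard failure. A minimal sketch, assuming the same model variables as app.py (the `load_with_fallback` helper name is hypothetical):

    import torch
    from transformers import AutoModelForCausalLM

    def load_with_fallback(model_name, hf_token):
        # Try the flash-attention kernels first; fall back to the default
        # attention implementation if flash-attn is not available.
        try:
            return AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=torch.float16,
                device_map="auto",
                token=hf_token,
                attn_implementation="flash_attention_2",
            )
        except (ImportError, ValueError):
            return AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=torch.float16,
                device_map="auto",
                token=hf_token,
            )
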
@@ -162,28 +159,15 @@ def generate_text(prompt, max_length=1024, temperature=0.7, top_p=0.95):
         return "Please enter a prompt to generate text."
 
     try:
+        # Keep generation simple to avoid errors
         inputs = global_tokenizer(prompt, return_tensors="pt").to(global_model.device)
 
-        generation_config = {
-            "max_length": max_length,
-            "do_sample": True,
-            "pad_token_id": global_tokenizer.eos_token_id,
-        }
-
-        # Only add temperature if it's not too low (can cause probability issues)
-        if temperature >= 0.2:
-            generation_config["temperature"] = temperature
-        else:
-            generation_config["temperature"] = 0.2
-
-        # Only add top_p if it's valid
-        if 0 < top_p < 1:
-            generation_config["top_p"] = top_p
-
-        # Generate text with safer parameters
+        # Use simpler generation parameters that work reliably
         outputs = global_model.generate(
-            **inputs,
-            **generation_config
+            inputs.input_ids,
+            max_length=min(2048, max_length + len(inputs.input_ids[0])),
+            temperature=max(0.3, temperature),  # Prevent too low temperature
+            do_sample=True
         )
 
         # Decode and return the generated text
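
Note: passing only `inputs.input_ids` drops the `attention_mask` and `pad_token_id` that the old `generation_config` supplied; transformers typically warns about both when sampling. A sketch that keeps the commit's simplified call while restoring them (assumes the `global_tokenizer`/`global_model` globals from app.py):

    # Keep the simple single-call style, but pass the attention mask and
    # an explicit pad token id to avoid open-ended generation warnings.
    inputs = global_tokenizer(prompt, return_tensors="pt").to(global_model.device)
    outputs = global_model.generate(
        inputs.input_ids,
        attention_mask=inputs.attention_mask,
        pad_token_id=global_tokenizer.eos_token_id,
        max_length=min(2048, max_length + inputs.input_ids.shape[1]),
        temperature=max(0.3, temperature),
        do_sample=True,
    )
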
@@ -191,8 +175,9 @@ def generate_text(prompt, max_length=1024, temperature=0.7, top_p=0.95):
         return generated_text
     except Exception as e:
         error_msg = str(e)
+        print(f"Generation error: {error_msg}")
         if "probability tensor" in error_msg:
-            return "Error: There was a problem with the generation parameters. Try using higher temperature (0.5+) and top_p values (0.9+)."
+            return "Error: There was a problem with the generation parameters. Try using simpler parameters or a different prompt."
         else:
             return f"Error generating text: {error_msg}"
 
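
Note: the "probability tensor" string matched here comes from `torch.multinomial`, which `generate` calls when `do_sample=True`; it fires when the softmax output contains `inf`/`nan` (for example from fp16 overflow or a near-zero temperature), which is what the `max(0.3, temperature)` clamp above guards against. A two-line repro of the underlying error:

    import torch
    probs = torch.tensor([float("nan"), 0.5])
    # RuntimeError: probability tensor contains either `inf`, `nan` or element < 0
    torch.multinomial(probs, num_samples=1)
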
@@ -247,12 +232,27 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
         )
 
         with gr.Column(scale=1):
-            auth_button = gr.Button("Authenticate")
+            auth_button = gr.Button("Authenticate", variant="primary")
 
-            auth_status = gr.Markdown("Please authenticate to use the model.")
+            with gr.Group(visible=True) as auth_message_group:
+                auth_status = gr.Markdown("Please authenticate to use the model.")
 
+    def authenticate(token):
+        auth_message_group.visible = True
+        return "Loading model... Please wait, this may take a minute."
+
+    def auth_complete(token):
+        result = load_model(token)
+        return result
+
+    # Two-step authentication to show loading message
     auth_button.click(
-        fn=load_model,
+        fn=authenticate,
+        inputs=[hf_token],
+        outputs=[auth_status],
+        queue=False
+    ).then(
+        fn=auth_complete,
         inputs=[hf_token],
         outputs=[auth_status]
     )
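
Note: the chained `.click(...).then(...)` is standard Gradio 3.x event sequencing — the first handler runs unqueued and returns immediately, so the status text updates before the slow model load starts. One caveat: the `auth_message_group.visible = True` assignment inside `authenticate` has no effect once the UI is built; visibility changes must be returned as `gr.update(visible=True)`. A self-contained sketch of the two-step pattern (assuming Gradio 3.x, where event listeners support `.then()`):

    import time
    import gradio as gr

    with gr.Blocks() as demo:
        btn = gr.Button("Run")
        status = gr.Markdown("Idle")

        def show_loading():
            return "Working... please wait."

        def do_work():
            time.sleep(2)  # stand-in for the slow model load
            return "Done."

        # Fast unqueued status update first, then the long-running step.
        btn.click(fn=show_loading, outputs=[status], queue=False).then(
            fn=do_work, outputs=[status]
        )

    demo.launch()
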
@@ -1019,6 +1019,8 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
 
 # Load default token if available
 if DEFAULT_HF_TOKEN:
-    demo.load(fn=load_model, inputs=[hf_token], outputs=[auth_status])
+    demo.load(fn=authenticate, inputs=[hf_token], outputs=[auth_status]).then(
+        fn=auth_complete, inputs=[hf_token], outputs=[auth_status]
+    )
 
-demo.launch()
+demo.launch(share=False)
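
Note: with the long-running `auth_complete` step now chained behind both the button and `demo.load`, enabling the request queue before launch keeps the slow step from blocking an HTTP request. A sketch; the `max_size` value is an arbitrary choice, not from the commit:

    # Enable the queue so chained long-running events are processed
    # asynchronously instead of tying up a request.
    demo.queue(max_size=16)
    demo.launch(share=False)
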