Commit 45c882e · Parent(s): c1d34f4

app.py CHANGED

@@ -44,15 +44,12 @@ def load_model(hf_token):
         token=hf_token
     )
 
-    # Load model with
+    # Load model with minimal configuration to avoid errors
     global_model = AutoModelForCausalLM.from_pretrained(
         model_name,
         torch_dtype=torch.float16,
         device_map="auto",
-        token=hf_token
-        use_cache=True,
-        low_cpu_mem_usage=True,
-        attn_implementation="flash_attention_2" if torch.cuda.is_available() else "eager"
+        token=hf_token
     )
 
     model_loaded = True
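
The hunk above strips the model load down to the arguments from_pretrained actually needs. For orientation, the loader after this commit reduces to roughly the sketch below; the module-level globals, the AutoTokenizer call, the placeholder model_name, and the return string are assumptions inferred from the surrounding diff, not code taken verbatim from the Space.

# Sketch of load_model as it stands after this change (assumed surrounding code).
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

global_tokenizer = None
global_model = None
model_loaded = False
model_name = "some-org/some-causal-lm"  # placeholder; the Space defines its own model_name

def load_model(hf_token):
    global global_tokenizer, global_model, model_loaded
    global_tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        token=hf_token
    )
    # Load model with minimal configuration to avoid errors
    global_model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16,
        device_map="auto",
        token=hf_token
    )
    model_loaded = True
    return "Model loaded successfully."  # assumed status message for the auth_status box
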
@@ -162,28 +159,15 @@ def generate_text(prompt, max_length=1024, temperature=0.7, top_p=0.95):
         return "Please enter a prompt to generate text."
 
     try:
+        # Keep generation simple to avoid errors
         inputs = global_tokenizer(prompt, return_tensors="pt").to(global_model.device)
 
-        generation_config = {
-            "max_length": max_length,
-            "do_sample": True,
-            "pad_token_id": global_tokenizer.eos_token_id,
-        }
-
-        # Only add temperature if it's not too low (can cause probability issues)
-        if temperature >= 0.2:
-            generation_config["temperature"] = temperature
-        else:
-            generation_config["temperature"] = 0.2
-
-        # Only add top_p if it's valid
-        if 0 < top_p < 1:
-            generation_config["top_p"] = top_p
-
-        # Generate text with safer parameters
+        # Use simpler generation parameters that work reliably
         outputs = global_model.generate(
-            …
-            **generation_config
+            inputs.input_ids,
+            max_length=min(2048, max_length + len(inputs.input_ids[0])),
+            temperature=max(0.3, temperature),  # Prevent too low temperature
+            do_sample=True
        )
 
         # Decode and return the generated text
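
The new call bounds the total sequence at 2048 tokens while still allowing up to max_length freshly generated tokens, i.e. max_length=min(2048, max_length + prompt_length). The same bound can be written with max_new_tokens, which transformers' generate also accepts; the snippet below is only an illustrative sketch meant to sit inside the same try block, and the pad_token_id line is an assumption carried over from the removed generation_config rather than something this commit adds.

# Illustrative alternative inside generate_text, not part of the commit:
# the same length cap expressed with max_new_tokens.
prompt_len = inputs.input_ids.shape[1]
outputs = global_model.generate(
    inputs.input_ids,
    max_new_tokens=min(2048 - prompt_len, max_length),  # same total-length bound as min(2048, max_length + prompt_len)
    temperature=max(0.3, temperature),
    do_sample=True,
    pad_token_id=global_tokenizer.eos_token_id,  # assumption: kept from the removed config
)
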
@@ -191,8 +175,9 @@ def generate_text(prompt, max_length=1024, temperature=0.7, top_p=0.95):
         return generated_text
     except Exception as e:
         error_msg = str(e)
+        print(f"Generation error: {error_msg}")
         if "probability tensor" in error_msg:
-            return "Error: There was a problem with the generation parameters. Try using
+            return "Error: There was a problem with the generation parameters. Try using simpler parameters or a different prompt."
         else:
             return f"Error generating text: {error_msg}"
 
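
The "probability tensor" check targets the RuntimeError that torch.multinomial raises during sampling when the probabilities contain NaN, Inf, or negative values (for example after an fp16 overflow or with extreme sampling settings). A minimal standalone reproduction, independent of the Space:

# Standalone illustration of the error the except branch looks for.
import torch

probs = torch.tensor([float("nan"), 0.5])
try:
    torch.multinomial(probs, num_samples=1)
except RuntimeError as e:
    print(e)  # message complains about the probability tensor containing inf/nan or an element < 0
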
@@ -247,12 +232,27 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
             )
 
         with gr.Column(scale=1):
-            auth_button = gr.Button("Authenticate")
+            auth_button = gr.Button("Authenticate", variant="primary")
 
-            …
+            with gr.Group(visible=True) as auth_message_group:
+                auth_status = gr.Markdown("Please authenticate to use the model.")
 
+    def authenticate(token):
+        auth_message_group.visible = True
+        return "Loading model... Please wait, this may take a minute."
+
+    def auth_complete(token):
+        result = load_model(token)
+        return result
+
+    # Two-step authentication to show loading message
     auth_button.click(
-        fn=
+        fn=authenticate,
+        inputs=[hf_token],
+        outputs=[auth_status],
+        queue=False
+    ).then(
+        fn=auth_complete,
         inputs=[hf_token],
         outputs=[auth_status]
     )
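
The click(...).then(...) chain is the standard Gradio pattern for showing immediate feedback before a slow step: the first callback returns the loading message right away (queue=False so it is not held up by the queue), and the chained callback performs the actual model load. A self-contained sketch of the same pattern, using illustrative names rather than the Space's own:

# Standalone sketch of the two-step click pattern; names are illustrative.
import time
import gradio as gr

def show_loading(_token):
    return "Loading model... Please wait, this may take a minute."

def do_load(token):
    time.sleep(2)  # stand-in for the real load_model(token)
    return "Model loaded." if token else "No token provided."

with gr.Blocks() as demo:
    token_box = gr.Textbox(label="Hugging Face token", type="password")
    button = gr.Button("Authenticate", variant="primary")
    status = gr.Markdown("Please authenticate to use the model.")
    button.click(
        fn=show_loading, inputs=[token_box], outputs=[status], queue=False
    ).then(
        fn=do_load, inputs=[token_box], outputs=[status]
    )

if __name__ == "__main__":
    demo.launch()
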
@@ -1019,6 +1019,8 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
 
 # Load default token if available
 if DEFAULT_HF_TOKEN:
-    demo.load(fn=
+    demo.load(fn=authenticate, inputs=[hf_token], outputs=[auth_status]).then(
+        fn=auth_complete, inputs=[hf_token], outputs=[auth_status]
+    )
 
-demo.launch()
+demo.launch(share=False)
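
demo.load(...) fires when the page is opened, so with DEFAULT_HF_TOKEN set the Space now authenticates automatically through the same two callbacks as the button, and demo.launch(share=False) simply makes the default non-shared launch explicit. Where DEFAULT_HF_TOKEN comes from is outside this hunk; a common arrangement, assumed here and not confirmed by the diff, reads it from a Space secret exposed as an environment variable and pre-fills the token textbox with it.

# Assumed setup near the top of app.py (not part of this hunk); the secret
# name "HF_TOKEN" is a guess.
import os

DEFAULT_HF_TOKEN = os.environ.get("HF_TOKEN", "")

# ...and later, inside the Blocks layout, the textbox would presumably be pre-filled
# so the demo.load(...) chain receives a token on page load:
# hf_token = gr.Textbox(label="Hugging Face Token", type="password", value=DEFAULT_HF_TOKEN)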
|