hmrizal committed
Commit bc3e7d7 · verified · 1 Parent(s): 0cca13a

remove fallback model completely, and uncomment

Files changed (1)
  1. app.py +24 -143
app.py CHANGED
@@ -78,21 +78,9 @@ MODEL_CONFIG = {
         "description": "Lightweight T5 model optimized for instruction following",
         "dtype": torch.float16 if torch.cuda.is_available() else torch.float32,
         "is_t5": True
-    },
-    "Fallback Model": {
-        "name": "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T",
-        "description": "Model sangat ringan untuk fallback",
-        "dtype": torch.float16 if torch.cuda.is_available() else torch.float32
     }
 }
 
-# Add the fallback model to MODEL_CONFIG
-# MODEL_CONFIG["Fallback Model"] = {
-#     "name": "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T",
-#     "description": "Model sangat ringan untuk fallback",
-#     "dtype": torch.float16 if torch.cuda.is_available() else torch.float32
-# }
-
 def initialize_model_once(model_key):
     with MODEL_CACHE["init_lock"]:
         current_model = MODEL_CACHE["model_name"]
@@ -160,35 +148,21 @@ def initialize_model_once(model_key):
 
             # Handle standard HF models
             else:
+                quantization_config = BitsAndBytesConfig(
+                    load_in_4bit=True,
+                    bnb_4bit_compute_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+                    bnb_4bit_quant_type="nf4",
+                    bnb_4bit_use_double_quant=True
+                )
                 MODEL_CACHE["tokenizer"] = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
-
-                # Only use quantization if CUDA is available
-                if torch.cuda.is_available():
-                    quantization_config = BitsAndBytesConfig(
-                        load_in_4bit=True,
-                        bnb_4bit_compute_dtype=torch.float16,
-                        bnb_4bit_quant_type="nf4",
-                        bnb_4bit_use_double_quant=True
-                    )
-
-                    MODEL_CACHE["model"] = AutoModelForCausalLM.from_pretrained(
-                        model_name,
-                        quantization_config=quantization_config,
-                        torch_dtype=model_info["dtype"],
-                        device_map="auto",
-                        low_cpu_mem_usage=True,
-                        trust_remote_code=True
-                    )
-                else:
-                    # For CPU-only environments, load without quantization
-                    MODEL_CACHE["model"] = AutoModelForCausalLM.from_pretrained(
-                        model_name,
-                        torch_dtype=torch.float32,  # Use float32 for CPU
-                        device_map=None,
-                        low_cpu_mem_usage=True,
-                        trust_remote_code=True
-                    )
-
+                MODEL_CACHE["model"] = AutoModelForCausalLM.from_pretrained(
+                    model_name,
+                    quantization_config=quantization_config,
+                    torch_dtype=model_info["dtype"],
+                    device_map="auto" if torch.cuda.is_available() else None,
+                    low_cpu_mem_usage=True,
+                    trust_remote_code=True
+                )
                 MODEL_CACHE["is_gguf"] = False
 
                 print(f"Model {model_name} loaded successfully")
@@ -206,9 +180,6 @@ def create_llm_pipeline(model_key):
         print(f"Creating pipeline for model: {model_key}")
         tokenizer, model, is_gguf = initialize_model_once(model_key)
 
-        # Get the model info for reference
-        model_info = MODEL_CONFIG[model_key]
-
         if model is None:
             raise ValueError(f"Model is None for {model_key}")
 
@@ -258,85 +229,22 @@ def create_llm_pipeline(model_key):
         import traceback
         print(f"Error creating pipeline: {str(e)}")
         print(traceback.format_exc())
-
-        # Fall back to a simpler model if the main one fails
-        if model_key != "Fallback Model":
-            print(f"Trying fallback model")
-            try:
-                return create_fallback_pipeline()
-            except:
-                raise RuntimeError(f"Failed to create pipeline: {str(e)}")
-        else:
-            raise RuntimeError(f"Failed to create pipeline: {str(e)}")
-
-def create_fallback_pipeline():
-    """Create a fallback pipeline with a very small model"""
-    model_key = "Fallback Model"
-    print(f"Creating minimal fallback pipeline with {MODEL_CONFIG[model_key]['name']}")
-
-    # Avoid using bitsandbytes for quantization when CUDA is not available
-    try:
-        tokenizer = AutoTokenizer.from_pretrained(MODEL_CONFIG[model_key]["name"])
-
-        # Load model in 8-bit or without quantization for CPU
-        if torch.cuda.is_available():
-            model = AutoModelForCausalLM.from_pretrained(
-                MODEL_CONFIG[model_key]["name"],
-                torch_dtype=MODEL_CONFIG[model_key]["dtype"],
-                device_map="auto",
-                low_cpu_mem_usage=True
-            )
-        else:
-            # For CPU-only environments, avoid quantization
-            model = AutoModelForCausalLM.from_pretrained(
-                MODEL_CONFIG[model_key]["name"],
-                torch_dtype=torch.float32,  # Use float32 for CPU
-                low_cpu_mem_usage=True
-            )
-
-        pipe = pipeline(
-            "text-generation",
-            model=model,
-            tokenizer=tokenizer,
-            max_new_tokens=64,  # Reduced for CPU performance
-            temperature=0.3,
-            return_full_text=False,
-        )
-
-        return HuggingFacePipeline(pipeline=pipe)
-    except Exception as e:
-        print(f"Error creating minimal fallback pipeline: {str(e)}")
-        raise
 
 def handle_model_loading_error(model_key, session_id):
-    """Handle model loading errors with fallback options"""
-    fallback_hierarchy = [
+    """Handle model loading errors by providing alternative model suggestions"""
+    suggested_models = [
         "DeepSeek Coder Instruct",  # 1.3B model
-        "Phi-4",  # 1.5B model
-        "TinyLlama-Chat",  # 1.1B model
-        "Flan-T5-Small"  # Lightest model
+        "Phi-4 Mini Instruct",  # Light model
+        "TinyLlama Chat",  # 1.1B model
+        "Flan T5 Small"  # Lightweight T5
     ]
 
-    # If the failing model is already the last fallback, return an error message
-    if model_key == fallback_hierarchy[-1]:
-        return None, f"Tidak dapat memuat model {model_key}. Harap coba lagi nanti."
+    # Remove the current model from suggestions if it's in the list
+    if model_key in suggested_models:
+        suggested_models.remove(model_key)
 
-    # Find the position of the failing model in the hierarchy
-    try:
-        current_index = fallback_hierarchy.index(model_key)
-    except ValueError:
-        current_index = -1
-
-    # Try the next model in the hierarchy
-    for fallback_model in fallback_hierarchy[current_index+1:]:
-        try:
-            print(f"Trying fallback model: {fallback_model}")
-            chatbot = ChatBot(session_id, fallback_model)
-            return chatbot, f"Model {model_key} tidak tersedia. Menggunakan {fallback_model} sebagai alternatif."
-        except Exception as e:
-            print(f"Fallback model {fallback_model} also failed: {str(e)}")
-
-    return None, "Semua model gagal dimuat. Harap coba lagi nanti."
+    suggestions = ", ".join(suggested_models[:3])  # Only show top 3 suggestions
+    return None, f"Tidak dapat memuat model {model_key}. Silakan coba model lain seperti: {suggestions}"
 
 def create_conversational_chain(db, file_path, model_key):
     llm = create_llm_pipeline(model_key)
@@ -703,18 +611,6 @@ def create_gradio_interface():
                 import traceback
                 print(f"Error processing file with {model_key}: {str(e)}")
                 print(traceback.format_exc())
-
-                # Try a fallback model
-                try:
-                    chatbot, message = handle_model_loading_error(model_key, sess_id)
-                    if chatbot is not None:
-                        result = chatbot.process_file(file)
-                        return chatbot, True, [(None, message), (None, result)]
-                    else:
-                        return None, False, [(None, message)]
-                except Exception as fb_err:
-                    error_msg = f"Error dengan model {model_key}: {str(e)}\n\nFallback juga gagal: {str(fb_err)}"
-                    return None, False, [(None, error_msg)]
 
         process_button.click(
             fn=handle_process_file,
@@ -737,21 +633,6 @@ def create_gradio_interface():
             outputs=[chatbot_state, model_selected, chatbot_interface, model_dropdown]
         )
 
-        # Change model handler
-        # def handle_model_change(model_key, chatbot, sess_id):
-        #     if chatbot is None:
-        #         chatbot = ChatBot(sess_id, model_key)
-        #         return chatbot, [(None, f"Model diatur ke {model_key}. Silakan upload file CSV.")]
-
-        #     result = chatbot.change_model(model_key)
-        #     return chatbot, chatbot.chat_history + [(None, result)]
-
-        # change_model_button.click(
-        #     fn=handle_model_change,
-        #     inputs=[model_dropdown, chatbot_state, session_id],
-        #     outputs=[chatbot_state, chatbot_interface]
-        # )
-
         # Chat handlers
         def user_message_submitted(message, history, chatbot, sess_id):
             history = history + [(message, None)]
 