Steph254 committed on
Commit f2dcdc2 · verified · 1 Parent(s): 79ccf40

Update app.py

Files changed (1)
  1. app.py +26 -25
app.py CHANGED
@@ -14,38 +14,39 @@ QLORA_ADAPTER = "meta-llama/Llama-3.2-1B-Instruct-QLORA_INT4_EO8" # Ensure this
  LLAMA_GUARD_NAME = "meta-llama/Llama-Guard-3-1B-INT4" # Ensure this is correct

  # Function to load Llama model
- def load_llama_model(model_name, is_guard=False):
-     print(f"Loading model: {model_name}")
+ def load_llama_model(model_path, is_guard=False):
+     print(f"Loading model: {model_path}")
+
      try:
          # Load tokenizer
-         tokenizer = AutoTokenizer.from_pretrained(
-             model_name,
-             use_fast=False,
-             token=HUGGINGFACE_TOKEN
-         )
-
-         # Load model
-         model = AutoModelForCausalLM.from_pretrained(
-             model_name,
-             torch_dtype=torch.float32,
-             device_map="cpu", # Ensure it runs on CPU
-             token=HUGGINGFACE_TOKEN
-         )
-
+         tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, token=HUGGINGFACE_TOKEN)
+
+         # Load config first (to avoid shape mismatch errors)
+         config = AutoModelForCausalLM.from_pretrained(BASE_MODEL, config_only=True).config
+
+         # 🔹 Manually load the `.pth` file
+         state_dict_path = os.path.join(model_path, "consolidated.00.pth")
+         if not os.path.exists(state_dict_path):
+             raise FileNotFoundError(f"Missing model weights: {state_dict_path}")
+
+         state_dict = torch.load(state_dict_path, map_location="cpu")
+
+         # Load model from config and manually apply weights
+         model = AutoModelForCausalLM.from_config(config)
+         model.load_state_dict(state_dict, strict=False) # Use strict=False to allow missing keys
+         model.eval() # Set to inference mode
+
          # Load QLoRA adapter if applicable
-         if not is_guard and "QLORA" in model_name:
+         if not is_guard and "QLORA" in model_path:
              print("Loading QLoRA adapter...")
-             model = PeftModel.from_pretrained(
-                 model,
-                 model_name,
-                 token=HUGGINGFACE_TOKEN
-             )
+             model = PeftModel.from_pretrained(model, model_path, token=HUGGINGFACE_TOKEN)
              print("Merging LoRA weights...")
-             model = model.merge_and_unload() # Merge LoRA weights for inference
-
+             model = model.merge_and_unload()
+
          return tokenizer, model
+
      except Exception as e:
-         print(f"Error loading model {model_name}: {e}")
+         print(f"Error loading model {model_path}: {e}")
          raise

  # Load Llama 3.2 model
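
For reference, below is a cleaned-up, standalone sketch of the new loading path introduced by this commit. It is illustrative only, under a few assumptions not present in the diff: BASE_MODEL is taken to be the base repo "meta-llama/Llama-3.2-1B-Instruct", HUGGINGFACE_TOKEN is read from the environment, and the commit's config_only=True argument (which is not a standard from_pretrained keyword) is replaced with AutoConfig.from_pretrained to fetch only the configuration.

# Hypothetical sketch of the new load_llama_model path; not the exact repository code.
import os

import torch
from peft import PeftModel
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

BASE_MODEL = "meta-llama/Llama-3.2-1B-Instruct"  # assumed base repo
HUGGINGFACE_TOKEN = os.environ.get("HF_TOKEN")   # assumed token source

def load_llama_model(model_path, is_guard=False):
    print(f"Loading model: {model_path}")
    try:
        # Tokenizer and config come from the base repo, not the quantized/adapter repo
        tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, token=HUGGINGFACE_TOKEN)
        config = AutoConfig.from_pretrained(BASE_MODEL, token=HUGGINGFACE_TOKEN)

        # Manually load the consolidated .pth checkpoint from disk
        state_dict_path = os.path.join(model_path, "consolidated.00.pth")
        if not os.path.exists(state_dict_path):
            raise FileNotFoundError(f"Missing model weights: {state_dict_path}")
        state_dict = torch.load(state_dict_path, map_location="cpu")

        # Build the model skeleton from the config, then apply the weights.
        # Note: a Meta-format consolidated.*.pth may not use Hugging Face key names;
        # strict=False only skips mismatched keys, it does not convert them.
        model = AutoModelForCausalLM.from_config(config)
        model.load_state_dict(state_dict, strict=False)
        model.eval()

        # Attach and merge the QLoRA adapter for plain inference, if applicable
        if not is_guard and "QLORA" in model_path:
            print("Loading QLoRA adapter...")
            model = PeftModel.from_pretrained(model, model_path, token=HUGGINGFACE_TOKEN)
            print("Merging LoRA weights...")
            model = model.merge_and_unload()

        return tokenizer, model
    except Exception as e:
        print(f"Error loading model {model_path}: {e}")
        raise

Presumably app.py then calls something like load_llama_model(QLORA_ADAPTER) for the chat model and load_llama_model(LLAMA_GUARD_NAME, is_guard=True) for Llama Guard, but those call sites fall outside this hunk.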