Steph254 committed · verified
Commit 72aeff1 · Parent: b7a28cd

Update app.py

Files changed (1): app.py (+31, -15)
app.py CHANGED

@@ -18,28 +18,44 @@ def load_llama_model(model_path, is_guard=False):
     print(f"Loading model: {model_path}")
 
     try:
-        # Load tokenizer
-        tokenizer = LlamaTokenizer.from_pretrained(BASE_MODEL, token=HUGGINGFACE_TOKEN)
+        # Check if token exists and is valid
+        token = os.getenv("HUGGINGFACE_TOKEN")
+        if not token:
+            raise ValueError("HUGGINGFACE_TOKEN environment variable is not set")
+
+        # Load tokenizer with proper token
+        tokenizer = LlamaTokenizer.from_pretrained(
+            BASE_MODEL,
+            token=token,
+            use_fast=False  # Sometimes helps with compatibility issues
+        )
 
         # Load config first (to avoid shape mismatch errors)
-        config = AutoModelForCausalLM.from_pretrained(BASE_MODEL, config_only=True).config
-
-        # 🔹 Manually load the `.pth` file
-        state_dict_path = os.path.join(model_path, "consolidated.00.pth")
-        if not os.path.exists(state_dict_path):
-            raise FileNotFoundError(f"Missing model weights: {state_dict_path}")
-
-        state_dict = torch.load(state_dict_path, map_location="cpu")
-
-        # Load model from config and manually apply weights
-        model = AutoModelForCausalLM.from_config(config)
-        model.load_state_dict(state_dict, strict=False)  # Use strict=False to allow missing keys
+        config = AutoModelForCausalLM.from_pretrained(
+            BASE_MODEL,
+            config_only=True,
+            token=token
+        ).config
+
+        # Load model from config
+        model = AutoModelForCausalLM.from_pretrained(
+            model_path,
+            token=token,
+            config=config,
+            device_map="auto",  # Better device management
+            torch_dtype=torch.float16  # Use half precision for efficiency
+        )
+
         model.eval()  # Set to inference mode
 
         # Load QLoRA adapter if applicable
         if not is_guard and "QLORA" in model_path:
             print("Loading QLoRA adapter...")
-            model = PeftModel.from_pretrained(model, model_path, token=HUGGINGFACE_TOKEN)
+            model = PeftModel.from_pretrained(
+                model,
+                model_path,
+                token=token
+            )
             print("Merging LoRA weights...")
             model = model.merge_and_unload()
 
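
A note on the removed path: Meta's consolidated.00.pth checkpoints use a different key layout than a Transformers model built with from_config, so load_state_dict(strict=False) can silently skip most of the weights instead of loading them, which is likely why this path was dropped. A minimal sketch of how to surface that mismatch, with BASE_MODEL and state_dict_path as placeholders standing in for the app's values:

import torch
from transformers import AutoConfig, AutoModelForCausalLM

BASE_MODEL = "meta-llama/Llama-2-7b-hf"           # placeholder for the app's BASE_MODEL
state_dict_path = "path/to/consolidated.00.pth"   # placeholder, Meta-format weights

config = AutoConfig.from_pretrained(BASE_MODEL)
model = AutoModelForCausalLM.from_config(config)
state_dict = torch.load(state_dict_path, map_location="cpu")

# load_state_dict returns the keys it could not match; strict=False alone
# hides the problem, so inspect the result instead of discarding it.
missing, unexpected = model.load_state_dict(state_dict, strict=False)
if missing or unexpected:
    raise RuntimeError(
        f"{len(missing)} missing / {len(unexpected)} unexpected keys: "
        "convert the Meta weights to the HF format before loading"
    )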
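
On the added path: config_only=True is not a documented argument of AutoModelForCausalLM.from_pretrained; the usual way to fetch just the config is AutoConfig.from_pretrained. A hedged sketch of an equivalent loading flow under that substitution (BASE_MODEL is again a placeholder, and the commit passes the fine-tuned model_path rather than the base checkpoint to the model load; device_map="auto" requires the accelerate package):

import os
import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

BASE_MODEL = "meta-llama/Llama-2-7b-hf"  # placeholder for the app's BASE_MODEL

token = os.getenv("HUGGINGFACE_TOKEN")
if not token:
    raise ValueError("HUGGINGFACE_TOKEN environment variable is not set")

# Slow tokenizer, matching the commit's use_fast=False choice
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, token=token, use_fast=False)

# Fetch only the config; no weights are downloaded at this step
config = AutoConfig.from_pretrained(BASE_MODEL, token=token)

model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,                 # the commit passes model_path here
    token=token,
    config=config,
    device_map="auto",          # let accelerate place layers across devices
    torch_dtype=torch.float16,  # half precision roughly halves memory use
)
model.eval()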
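
The adapter branch follows the standard PEFT flow: merge_and_unload() folds the LoRA deltas into the base weights and returns a plain Transformers model, so inference afterwards pays no adapter overhead. A short usage sketch, assuming model was loaded as above and model_path is a placeholder for a repo or directory containing trained adapter files (adapter_config.json plus adapter weights):

from peft import PeftModel

model = PeftModel.from_pretrained(model, model_path, token=token)
model = model.merge_and_unload()  # bake LoRA deltas in; returns the base model class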