Update app.py
Browse files
app.py
CHANGED
@@ -18,28 +18,44 @@ def load_llama_model(model_path, is_guard=False):
|
|
18 |
print(f"Loading model: {model_path}")
|
19 |
|
20 |
try:
|
21 |
-
#
|
22 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
|
24 |
# Load config first (to avoid shape mismatch errors)
|
25 |
-
config = AutoModelForCausalLM.from_pretrained(
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
|
|
|
|
|
|
37 |
model.eval() # Set to inference mode
|
38 |
|
39 |
# Load QLoRA adapter if applicable
|
40 |
if not is_guard and "QLORA" in model_path:
|
41 |
print("Loading QLoRA adapter...")
|
42 |
-
model = PeftModel.from_pretrained(
|
|
|
|
|
|
|
|
|
43 |
print("Merging LoRA weights...")
|
44 |
model = model.merge_and_unload()
|
45 |
|
|
|
18 |
print(f"Loading model: {model_path}")
|
19 |
|
20 |
try:
|
21 |
+
# Require HUGGINGFACE_TOKEN to be set (only its presence is checked, not its validity)
|
22 |
+
token = os.getenv("HUGGINGFACE_TOKEN")
|
23 |
+
if not token:
|
24 |
+
raise ValueError("HUGGINGFACE_TOKEN environment variable is not set")
|
25 |
+
|
26 |
+
# Load the BASE_MODEL tokenizer, passing the Hugging Face token for authentication
|
27 |
+
tokenizer = LlamaTokenizer.from_pretrained(
|
28 |
+
BASE_MODEL,
|
29 |
+
token=token,
|
30 |
+
use_fast=False # Sometimes helps with compatibility issues
|
31 |
+
)
|
32 |
|
33 |
# Load config first (to avoid shape mismatch errors)
|
34 |
+
config = AutoModelForCausalLM.from_pretrained(
|
35 |
+
BASE_MODEL,
|
36 |
+
config_only=True,
|
37 |
+
token=token
|
38 |
+
).config
|
39 |
+
|
40 |
+
# Load the model from model_path, using the config obtained from BASE_MODEL
|
41 |
+
model = AutoModelForCausalLM.from_pretrained(
|
42 |
+
model_path,
|
43 |
+
token=token,
|
44 |
+
config=config,
|
45 |
+
device_map="auto", # Better device management
|
46 |
+
torch_dtype=torch.float16 # Use half precision for efficiency
|
47 |
+
)
|
48 |
+
|
49 |
model.eval() # Set to inference mode
|
50 |
|
51 |
# Load QLoRA adapter if applicable
|
52 |
if not is_guard and "QLORA" in model_path:
|
53 |
print("Loading QLoRA adapter...")
|
54 |
+
model = PeftModel.from_pretrained(
|
55 |
+
model,
|
56 |
+
model_path,
|
57 |
+
token=token
|
58 |
+
)
|
59 |
print("Merging LoRA weights...")
|
60 |
model = model.merge_and_unload()
|
61 |
|