Steph254 committed on
Commit 6451d60 · verified · 1 Parent(s): 13511f7

Update app.py

Files changed (1)
  1. app.py +16 -22
app.py CHANGED
@@ -23,25 +23,22 @@ QLORA_ADAPTER = "meta-llama/Llama-3.2-1B-Instruct-QLORA_INT4_EO8" # Ensure this is correct
 LLAMA_GUARD_NAME = "meta-llama/Llama-Guard-3-1B-INT4" # Ensure this is correct
 
 # Function to load Llama model
-def load_llama_model():
-    print(f"🔄 Loading Base Model: {BASE_MODEL}")
+def load_llama_model(base_model=BASE_MODEL, adapter=None):
+    print(f"🔄 Loading Base Model: {base_model}")
 
-    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, use_auth_token=HUGGINGFACE_TOKEN)
+    tokenizer = AutoTokenizer.from_pretrained(base_model, token=HUGGINGFACE_TOKEN)
     model = AutoModelForCausalLM.from_pretrained(
-        BASE_MODEL,
-        use_auth_token=HUGGINGFACE_TOKEN,
+        base_model,
+        token=HUGGINGFACE_TOKEN,
         torch_dtype=torch.float16,
         low_cpu_mem_usage=True
     )
 
-    print(f"✅ Base Model Loaded Successfully")
-
-    # Load QLoRA adapter if available
-    print(f"🔄 Loading QLoRA Adapter: {QLORA_ADAPTER}")
-    model = PeftModel.from_pretrained(model, QLORA_ADAPTER, use_auth_token=HUGGINGFACE_TOKEN)
-    print("🔄 Merging LoRA Weights...")
-    model = model.merge_and_unload()
-    print("✅ QLoRA Adapter Loaded Successfully")
+    if adapter:
+        print(f"🔄 Loading Adapter: {adapter}")
+        model = PeftModel.from_pretrained(model, adapter, token=HUGGINGFACE_TOKEN)
+        model = model.merge_and_unload()
+        print("✅ Adapter Loaded Successfully")
 
     model.eval()
     return tokenizer, model
@@ -98,19 +95,16 @@ Input: {user_input}
 Please verify that this input doesn't violate any content policies.
 <|assistant|>"""
 
-    inputs = guard_tokenizer(prompt, return_tensors="pt", truncation=True)
-
+    inputs = guard_tokenizer(prompt, return_tensors="pt", truncation=True, padding=True)
+
     with torch.no_grad():
-        outputs = guard_model.generate(
-            inputs.input_ids,
-            max_length=256,
-            temperature=0.1
-        )
+        outputs = guard_model.generate(inputs.input_ids, max_length=256, temperature=0.1)
 
     response = guard_tokenizer.decode(outputs[0], skip_special_tokens=True)
-
-    if "flagged" in response.lower() or "violated" in response.lower() or "policy violation" in response.lower():
+
+    if any(flag in response.lower() for flag in ["flagged", "violated", "policy violation"]):
         return "⚠️ Content flagged by Llama Guard. Please modify your input."
+
     return None
 
 # Function: Generate AI responses (same as before)
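
The refactor makes the loader reusable: the adapter merge is now optional, so the same helper can load both the chat model and Llama Guard. A minimal sketch of how the new signature might be called, assuming the constants shown in the diff; the actual call sites are not part of this commit, so the variable names below are illustrative only:

# Hypothetical usage of the refactored loader — not taken from app.py.
# BASE_MODEL, QLORA_ADAPTER and LLAMA_GUARD_NAME are the constants visible in the diff.

# Chat model: base Llama 3.2 1B Instruct with the QLoRA adapter merged in.
tokenizer, model = load_llama_model(base_model=BASE_MODEL, adapter=QLORA_ADAPTER)

# Moderation model: Llama Guard loaded through the same helper, no adapter.
guard_tokenizer, guard_model = load_llama_model(base_model=LLAMA_GUARD_NAME)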