Update app.py
app.py (CHANGED)
@@ -1,18 +1,24 @@
 import gradio as gr
-from transformers import AutoTokenizer, AutoModelForCausalLM
+from transformers import AutoTokenizer, AutoModelForCausalLM
 import torch
 import json
 from datetime import datetime
 
-# Load Llama 3
-MODEL_NAME = "meta-llama/
+# Load Llama 3.2 (QLoRA) Model on CPU
+MODEL_NAME = "meta-llama/Llama-3.2-1B-Instruct-QLORA_INT4_EO8"
 tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
-model = AutoModelForCausalLM.from_pretrained(
+model = AutoModelForCausalLM.from_pretrained(
+    MODEL_NAME,
+    device_map="cpu"  # Force CPU usage
+)
 
-# Load Llama Guard for content moderation
+# Load Llama Guard for content moderation on CPU
 LLAMA_GUARD_NAME = "meta-llama/Llama-Guard-3-1B-INT4"
 guard_tokenizer = AutoTokenizer.from_pretrained(LLAMA_GUARD_NAME)
-guard_model = AutoModelForCausalLM.from_pretrained(
+guard_model = AutoModelForCausalLM.from_pretrained(
+    LLAMA_GUARD_NAME,
+    device_map="cpu"
+)
 
 # Define Prompt Templates
 PROMPTS = {
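Not part of the commit, but a quick way to check what this hunk intends: a minimal smoke test, assuming the model, guard_model, and tokenizer objects defined above, that confirms both checkpoints really sit on CPU and that a short greedy generation runs before Gradio starts serving.

import torch

def assert_on_cpu(m, name):
    # With device_map="cpu", every parameter should report device type "cpu".
    devices = {p.device.type for p in m.parameters()}
    assert devices == {"cpu"}, f"{name} unexpectedly has parameters on {devices}"

assert_on_cpu(model, "Llama 3.2")
assert_on_cpu(guard_model, "Llama Guard")

# Tiny greedy generation as a sanity check; a small max_new_tokens keeps it fast on CPU.
probe = tokenizer("Hello, world", return_tensors="pt")
with torch.no_grad():
    out = model.generate(**probe, max_new_tokens=8)
print(tokenizer.decode(out[0], skip_special_tokens=True))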
@@ -50,7 +56,7 @@ def moderate_input(user_input):
         return "⚠️ Content flagged by Llama Guard. Please modify your input."
     return None  # Safe input, proceed normally
 
-# Function: Generate AI responses
+# Function: Generate AI responses
 def generate_response(prompt_type, **kwargs):
     prompt = PROMPTS[prompt_type].format(**kwargs)
 
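The body of moderate_input sits outside this hunk, so only its return values are visible here. For orientation, a hedged sketch of how an input is typically screened with Llama Guard 3 through its chat template; the actual prompt construction in app.py may differ.

def moderate_input_sketch(user_input):
    # Llama Guard 3 is prompted through its chat template and answers with a
    # short verdict ("safe" or "unsafe", plus a hazard category when unsafe).
    conversation = [{"role": "user", "content": user_input}]
    input_ids = guard_tokenizer.apply_chat_template(conversation, return_tensors="pt")
    output = guard_model.generate(input_ids, max_new_tokens=20)
    verdict = guard_tokenizer.decode(output[0][input_ids.shape[-1]:], skip_special_tokens=True)
    if "unsafe" in verdict.lower():
        return "⚠️ Content flagged by Llama Guard. Please modify your input."
    return None  # Safe input, proceed normally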
@@ -58,11 +64,11 @@ def generate_response(prompt_type, **kwargs):
     if moderation_warning:
         return moderation_warning  # Stop processing if flagged
 
-    inputs = tokenizer(prompt, return_tensors="pt", max_length=
+    inputs = tokenizer(prompt, return_tensors="pt", max_length=1024, truncation=True)
 
     outputs = model.generate(
         inputs.input_ids,
-        max_length=
+        max_length=1024,
         temperature=0.7 if prompt_type == "project_analysis" else 0.5,
         top_p=0.9
     )
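The hunk ends at the generate() call; the decode and return step is unchanged and therefore not shown. Presumably the function finishes the same way the chat handler further down does, roughly:

    # Presumed continuation of generate_response (outside the diff):
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return response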
@@ -133,7 +139,7 @@ def create_gradio_interface():
 AI:"""
 
         inputs = tokenizer(prompt, return_tensors="pt")
-        outputs = model.generate(inputs.input_ids, max_length=
+        outputs = model.generate(inputs.input_ids, max_length=1024)
         response = tokenizer.decode(outputs[0], skip_special_tokens=True)
         chat_history.append((message, response))
         return "", chat_history
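Only the tail of the chat handler appears in this last hunk; the surrounding create_gradio_interface is unchanged and not shown. Purely as orientation, a hedged sketch of how a handler with this (message, chat_history) -> ("", chat_history) signature is commonly wired into Gradio Blocks; the names respond, msg, and chatbot are assumptions, not taken from app.py.

import gradio as gr

with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    msg = gr.Textbox(placeholder="Ask the assistant...")

    def respond(message, chat_history):
        # ...build the prompt, run model.generate, decode (as in the hunk above)...
        response = "stub reply"
        chat_history.append((message, response))
        # Returning "" clears the textbox; the updated history refreshes the chatbot.
        return "", chat_history

    msg.submit(respond, [msg, chatbot], [msg, chatbot])

demo.launch()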